
update tokenizer base; add qwen2 tokenizer (#3774)

* update tokenizer base; add qwen2 tokenizer

* fix typing hint
Zhang Zelun 7 months ago
parent
commit
24ab1ac00a

+ 1 - 0
paddlex/inference/models/common/tokenizer/__init__.py

@@ -15,4 +15,5 @@
 from .bert_tokenizer import BertTokenizer
 from .clip_tokenizer import CLIPTokenizer
 from .gpt_tokenizer import GPTTokenizer
+from .qwen2_tokenizer import MIXQwen2Tokenizer, Qwen2Tokenizer
 from .tokenizer_utils import PretrainedTokenizer
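
With this export in place, both classes are reachable from the tokenizer package; a minimal import sketch, assuming an installed PaddleX that keeps this module path:

```python
# Hypothetical import; mirrors the export added to __init__.py above.
from paddlex.inference.models.common.tokenizer import (
    MIXQwen2Tokenizer,
    Qwen2Tokenizer,
)
```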

+ 430 - 0
paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py

@@ -0,0 +1,430 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import unicodedata
+from functools import lru_cache
+from typing import List, Optional, Tuple
+
+import regex as re
+
+from .tokenizer_utils import PretrainedTokenizer
+from .tokenizer_utils_base import AddedToken, TextInput
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+__all__ = ["Qwen2Tokenizer", "MIXQwen2Tokenizer"]
+
+MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
+
+PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Qwen2Tokenizer(PretrainedTokenizer):
+    """
+    Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import Qwen2Tokenizer
+
+    >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
+    >>> tokenizer("Hello world")["input_ids"]
+    [9707, 1879]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [21927, 1879]
+    ```
+    This is expected.
+
+    You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token. Not applicable for this tokenizer.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+            tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
+            `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
+            '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
+    """
+
+    resource_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    max_model_input_sizes = MAX_MODEL_INPUT_SIZES
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        unk_token="<|endoftext|>",
+        bos_token=None,
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        clean_up_tokenization_spaces=False,
+        split_special_tokens=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        # Qwen vocab does not contain control tokens; added tokens need to be special
+        bos_token = (
+            AddedToken(
+                bos_token, lstrip=False, rstrip=False, special=True, normalized=False
+            )
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(
+                eos_token, lstrip=False, rstrip=False, special=True, normalized=False
+            )
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(
+                unk_token, lstrip=False, rstrip=False, special=True, normalized=False
+            )
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(
+                pad_token, lstrip=False, rstrip=False, special=True, normalized=False
+            )
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for long running processes
+        # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
+        # not a memory leak but appears as one.
+        # GPT2Tokenizer has the same problem, so let's be consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        self.bos_token_id = kwargs["bos_token_id"] if "bos_token_id" in kwargs else None
+        self.eos_token_id = kwargs["eos_token_id"] if "eos_token_id" in kwargs else None
+        self.unk_token_id = kwargs["unk_token_id"] if "unk_token_id" in kwargs else None
+        self.pad_token_id = kwargs["pad_token_id"] if "pad_token_id" in kwargs else None
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            split_special_tokens=split_special_tokens,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.encoder)
+
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(
+            token, self.added_tokens_encoder.get(token, len(self.encoder))
+        )
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(
+            index, self.added_tokens_decoder.get(index, self.unk_token)
+        )
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            "utf-8", errors=self.errors
+        )
+        return text
+
+    def _decode(
+        self,
+        token_ids,
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
+        # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
+        return super()._decode(
+            token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
+        )
+
+    def save_vocabulary(
+        self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["vocab_file"],
+        )
+        merge_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["merges_file"],
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(
+                json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False)
+                + "\n"
+            )
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(
+                self.bpe_ranks.items(), key=lambda kv: kv[1]
+            ):
+                if index != token_index:
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def prepare_for_tokenization(self, text, **kwargs):
+        text = unicodedata.normalize("NFC", text)
+        return (text, kwargs)
+
+
+class MIXQwen2Tokenizer(Qwen2Tokenizer):
+    def __init__(self, *args, **kwargs):
+        super(MIXQwen2Tokenizer, self).__init__(*args, **kwargs)
+
+    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
+        """
+        Converts a string in a sequence of tokens, using the tokenizer.
+
+        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
+        (BPE/SentencePieces/WordPieces). Takes care of added tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            **kwargs (additional keyword arguments):
+                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+        Returns:
+            `List[str]`: The list of tokens.
+        """
+
+        split_special_tokens = kwargs.pop(
+            "split_special_tokens", self.split_special_tokens
+        )
+
+        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
+        all_special_tokens_extended = dict(
+            (str(t), t)
+            for t in self.all_special_tokens_extended
+            if isinstance(t, AddedToken)
+        )
+
+        text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+        # TODO: should this be in the base class?
+        if hasattr(self, "do_lower_case") and self.do_lower_case:
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [
+                re.escape(s_tok)
+                for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
+            ]
+            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+            text = re.sub(
+                pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text
+            )
+
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(
+                self.unique_no_split_tokens
+            )  # don't split on any of the added tokens
+            # "This is something<special_token_1>  else"
+            tokens = self.tokens_trie.split(text)
+
+        # ["This is something", "<special_token_1>", "  else"]
+        for i, token in enumerate(tokens):
+            if token in no_split_token:
+                tok_extended = all_special_tokens_extended.get(token, None)
+                left = tokens[i - 1] if i > 0 else None
+                right = tokens[i + 1] if i < len(tokens) - 1 else None
+                if isinstance(tok_extended, AddedToken):
+                    if tok_extended.rstrip and right:
+                        # A bit counter-intuitive but we strip the left of the string
+                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                        tokens[i + 1] = right.lstrip()
+                    if tok_extended.lstrip and left:
+                        tokens[i - 1] = left.rstrip()
+
+        tokenized_text = []
+        for token in tokens:
+            # Need to skip eventual empty (fully stripped) tokens
+            if not token:
+                continue
+            if token in no_split_token:
+                tokenized_text.append(token)
+            else:
+                tokenized_text.extend(self._tokenize(token))
+        return tokenized_text
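
The new file is self-contained: given a `vocab.json` / `merges.txt` pair exported for a Qwen2 model, the tokenizer can be built directly from those files. A hedged usage sketch (file paths are illustrative and not shipped with this commit; expected ids follow the class docstring):

```python
from paddlex.inference.models.common.tokenizer import Qwen2Tokenizer

# Hypothetical local files; Qwen2Tokenizer only needs the byte-level BPE vocab and merges.
tokenizer = Qwen2Tokenizer(
    vocab_file="qwen2_assets/vocab.json",
    merges_file="qwen2_assets/merges.txt",
)

ids = tokenizer("Hello world")["input_ids"]        # e.g. [9707, 1879] per the docstring
tokens = tokenizer.convert_ids_to_tokens(ids)
print(tokenizer.convert_tokens_to_string(tokens))  # "Hello world"
```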

+ 48 - 43
paddlex/inference/models/common/tokenizer/tokenizer_utils.py

@@ -24,7 +24,7 @@ import unicodedata
 from collections import OrderedDict
 from dataclasses import asdict, dataclass
 from functools import lru_cache
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 
@@ -643,7 +643,7 @@ class ChatTemplateMixin:
 
     def apply_chat_template(
         self,
-        conversation: Union[Dict[str, str], str],
+        conversation: Union[List[List[str]], Dict[str, str], str],
         tokenize: bool = True,
         context_data: Dict[str, Any] = {},
         **tokenizer_kwargs,
@@ -651,7 +651,7 @@ class ChatTemplateMixin:
         """apply chat_template rules to conversation which should not be batched data
         """apply chat_template rules to conversation which should not be batched data
 
 
         Args:
         Args:
-            conversation (List[List[str, str]] | str): the conversation messages between user and bot
+            conversation (List[List[str]] , str): the conversation messages between user and bot
             context_data (Dict[str, Any]): the context data for chat_template.json
             context_data (Dict[str, Any]): the context data for chat_template.json
             tokenize (bool, optional): whether do tokenization. Defaults to True.
             tokenize (bool, optional): whether do tokenization. Defaults to True.
 
 
@@ -679,7 +679,7 @@ class ChatTemplateMixin:
 
     def _apply_chat_template_paddle(
         self,
-        conversation: Union[List[Dict[str, str]], str],
+        conversation: Union[List[List[str]], str],
         context_data: Dict[str, Any] = {},
     ):
         context_data = self.chat_template._init_context_data(context_data)
@@ -697,7 +697,7 @@ class ChatTemplateMixin:
 
     def _apply_chat_template(
         self,
-        conversation: Union[Dict[str, str], str],
+        conversation: Union[List[List[str]], Dict[str, str], str],
         add_generation_prompt=True,
     ):
         if isinstance(conversation, str):
@@ -722,7 +722,7 @@ class ChatTemplateMixin:
 
     def encode_chat_inputs(
         self,
-        conversations: List[Dict[str, str]],
+        conversations: List[List[str]],
         context_data: Dict[str, Any] = {},
         **kwargs,
     ):
@@ -731,7 +731,7 @@ class ChatTemplateMixin:
         Turn t: sep + bot + query             bot + eos
 
         Args:
-            conversation (List[Dict[str, str]]): the conversation of data
+            conversation (List[List[str]]): the conversation of data
             context_data (Dict[str, Any]): the context data of conversation
 
         Returns:
@@ -751,7 +751,7 @@ class ChatTemplateMixin:
         return query
 
     def _encode_chat_inputs_paddle(
-        self, conversations: List[Dict[str, str]], context_data: Dict[str, Any] = {}
+        self, conversations: List[List[str]], context_data: Dict[str, Any] = {}
     ):
         context_data = self.chat_template._init_context_data(context_data)
         # encode system
@@ -781,7 +781,7 @@ class ChatTemplateMixin:
 
     def _encode_chat_inputs(
         self,
-        conversations: List[Dict[str, str]],
+        conversations: List[List[str]],
         context_data: Dict[str, Any] = {},
         system: str = None,
         add_generation_prompt=True,
@@ -826,7 +826,9 @@ class ChatTemplateMixin:
             ans.append(ans_roundi)
 
         non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
-        assert len(non_learnable_parts) == len(ans)
+        assert len(non_learnable_parts) == len(
+            ans
+        ), f"Get non_learnable_parts len: {len(non_learnable_parts)}, but ans len: {len(ans)}."
 
         conversation_ids = []
         for i in range(len(non_learnable_parts)):
@@ -895,11 +897,11 @@ class ChatTemplateMixin:
         tokenizer.init_chat_template(chat_template_file)
         return tokenizer
 
-    def init_chat_template(self, chat_template: Union[str, Dict]):
+    def init_chat_template(self, chat_template: Union[str, dict]):
         """init chat_tempalte by file_path or template dict data
 
         Args:
-            chat_template (str | dict): file_path or template dict data
+            chat_template (str, dict): file_path or template dict data
         """
         if isinstance(chat_template, str):
             if not os.path.exists(chat_template):
@@ -995,8 +997,12 @@ class PretrainedTokenizer(
         init_dict.pop("self", None)
         super(PretrainedTokenizer, self).__init__(**init_dict)
 
-        self.added_tokens_encoder: Dict[str, int] = {}
-        self.added_tokens_decoder: Dict[int, str] = {}
+        self.added_tokens_decoder: Dict[int, AddedToken] = {}
+        self.added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {}))
+        self.added_tokens_encoder: Dict[str, int] = {
+            k.content: v for v, k in self.added_tokens_decoder.items()
+        }
+
         self.unique_no_split_tokens: List[str] = []
         self.tokens_trie = Trie()
 
@@ -1094,6 +1100,7 @@ class PretrainedTokenizer(
                 and self.convert_tokens_to_ids(token)
                 == self.convert_tokens_to_ids(self.unk_token)
                 and token not in tokens_to_add
+                and token not in self.added_tokens_encoder.keys()
             ):
                 tokens_to_add.append(token)
                 if self.verbose:
@@ -1182,6 +1189,11 @@ class PretrainedTokenizer(
         Returns:
             `List[str]`: The list of tokens.
         """
+
+        split_special_tokens = kwargs.pop(
+            "split_special_tokens", self.split_special_tokens
+        )
+
         # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
         all_special_tokens_extended = dict(
             (str(t), t)
@@ -1203,8 +1215,15 @@ class PretrainedTokenizer(
                 pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text
             )
 
-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(
+                self.unique_no_split_tokens
+            )  # don't split on any of the added tokens
+            # "This is something<special_token_1>  else"
+            tokens = self.tokens_trie.split(text)
 
         # ["This is something", "<special_token_1>", "  else"]
         for i, token in enumerate(tokens):
@@ -1289,7 +1308,9 @@ class PretrainedTokenizer(
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         if isinstance(ids, int):
             if ids in self.added_tokens_decoder:
-                return self.added_tokens_decoder[ids]
+                token = self.added_tokens_decoder[ids]
+                token = token.content if isinstance(token, AddedToken) else token
+                return token
             else:
                 return self._convert_id_to_token(ids)
         tokens = []
@@ -1298,7 +1319,9 @@ class PretrainedTokenizer(
             if skip_special_tokens and index in self.all_special_ids:
                 continue
             if index in self.added_tokens_decoder:
-                tokens.append(self.added_tokens_decoder[index])
+                token = self.added_tokens_decoder[index]
+                token = token.content if isinstance(token, AddedToken) else token
+                tokens.append(token)
             else:
                 tokens.append(self._convert_id_to_token(index))
         return tokens
@@ -1430,6 +1453,7 @@ class PretrainedTokenizer(
         stride: int = 0,
         is_split_into_words: bool = False,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_position_ids: Optional[bool] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -1494,6 +1518,7 @@ class PretrainedTokenizer(
             max_length=max_length,
             stride=stride,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_tensors=return_tensors,
             prepend_batch_axis=True,
             return_position_ids=return_position_ids,
@@ -1524,6 +1549,7 @@ class PretrainedTokenizer(
         stride: int = 0,
         is_split_into_words: bool = False,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_position_ids: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -1609,6 +1635,7 @@ class PretrainedTokenizer(
             max_length=max_length,
             stride=stride,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_position_ids=return_position_ids,
             return_attention_mask=return_attention_mask,
             return_token_type_ids=return_token_type_ids,
@@ -1633,6 +1660,7 @@ class PretrainedTokenizer(
         max_length: Optional[int] = None,
         stride: int = 0,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_position_ids: Optional[bool] = None,
         return_tensors: Optional[str] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -1761,6 +1789,7 @@ class PretrainedTokenizer(
                     max_length=max_length,
                     stride=stride,
                     pad_to_multiple_of=None,  # we pad in batch afterward
+                    padding_side=padding_side,  # we pad in batch afterward
                     return_position_ids=return_position_ids,  # we pad in batch afterward
                     return_attention_mask=False,  # we pad in batch afterward
                     return_token_type_ids=return_token_type_ids,
@@ -1783,6 +1812,7 @@ class PretrainedTokenizer(
             padding=padding_strategy.value,
             max_length=max_length,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_attention_mask=return_attention_mask,
         )
         if return_dict:
@@ -2031,31 +2061,6 @@ class PretrainedTokenizer(
         else:
             return text
 
-    def decode_token(
-        self,
-        all_input_ids: List[int],
-        prefix_offset: int = 0,
-        read_offset: int = 0,
-    ) -> Tuple[str, int, int]:
-        """tokenizer decoding for the streaming generation use case. This method can be overrided for tokenizer that doesn't follow this API"""
-        # The prefix text is necessary only to defeat cleanup algorithms in the decode
-        # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.decode(
-            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
-        )
-        new_text = self.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)
-
-        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
-            # utf-8 char at the end means it's a potential unfinished byte sequence
-            # from byte fallback tokenization.
-            # If it's in the middle, it's probably a real invalid id generated
-            # by the model
-            prefix_index = new_text.index(prefix_text)
-            new_text = new_text[prefix_index + len(prefix_text) :]
-            return new_text, read_offset, len(all_input_ids)
-        else:
-            return "", prefix_offset, read_offset
-
 
 def _is_control(char):
     """Checks whether `chars` is a control character."""

+ 329 - 114
paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py

@@ -13,14 +13,25 @@
 # limitations under the License.
 
 import copy
+import inspect
 import io
 import json
 import os
 import warnings
-from collections import OrderedDict, UserDict
+from collections import UserDict
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 
@@ -42,7 +53,6 @@ __all__ = [
 
 TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
 CHAT_TEMPLATE_CONFIG_NAME = "chat_template.json"
-CHAT_TEMPLATE_CONFIG_NAME = "chat_template.json"
 
 VERY_LARGE_INTEGER = int(
     1e30
@@ -287,10 +297,6 @@ class BatchEncoding(UserDict):
     def items(self):
         return self.data.items()
 
-    # After this point:
-    # Extended properties and methods only available for fast tokenizers
-    # not yet supported
-
     @property
     def encodings(self) -> Optional[List[FastEncoding]]:
         """
@@ -850,15 +856,17 @@ class SpecialTokensMixin:
         return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
     def add_special_tokens(
-        self, special_tokens_dict: Dict[str, Union[str, AddedToken]]
+        self,
+        special_tokens_dict: Dict[str, Union[str, AddedToken]],
+        replace_additional_special_tokens=True,
     ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
         special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
         current vocabulary).
 
-        Note,None When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
-        matrix of the model so that its embedding matrix matches the tokenizer.
+        When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
+        model so that its embedding matrix matches the tokenizer.
 
         In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
 
@@ -879,6 +887,13 @@ class SpecialTokensMixin:
 
                 Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                 assign the index of the `unk_token` to them).
+            replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
+                If `True`, the existing list of additional special tokens will be replaced by the list provided in
+                `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
 
         Returns:
             `int`: Number of tokens added to the vocabulary.
@@ -902,7 +917,7 @@ class SpecialTokensMixin:
         if not special_tokens_dict:
             return 0
 
-        added_tokens = 0
+        added_tokens = []
         for key, value in special_tokens_dict.items():
             assert (
                 key in self.SPECIAL_TOKENS_ATTRIBUTES
@@ -910,19 +925,37 @@ class SpecialTokensMixin:
 
             if self.verbose:
                 logging.info(f"Assigning {value} to the {key} key of the tokenizer")
-            setattr(self, key, value)
 
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(
                     isinstance(t, (str, AddedToken)) for t in value
                 ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
-                added_tokens += self.add_tokens(value, special_tokens=True)
+
+                to_add = []
+                for token in value:
+                    if (
+                        not replace_additional_special_tokens
+                        and str(token) in self.additional_special_tokens
+                    ):
+                        continue
+                    to_add.append(token)
+                if replace_additional_special_tokens and len(to_add) > 0:
+                    setattr(self, key, list(to_add))
+                else:
+                    self._additional_special_tokens.extend(to_add)
+                added_tokens += to_add
+
             else:
-                assert isinstance(
-                    value, (str, AddedToken)
-                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
-                added_tokens += self.add_tokens([value], special_tokens=True)
+                if not isinstance(value, (str, AddedToken)):
+                    raise ValueError(
+                        f"Token {value} for key {key} should be a str or an AddedToken instance"
+                    )
+                setattr(self, key, value)
+                if value not in added_tokens:
+                    added_tokens.append(value)
 
+        # if we are adding tokens that were not part of the vocab, we ought to add them
+        added_tokens = self.add_tokens(added_tokens, special_tokens=True)
         return added_tokens
 
     def add_tokens(
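
A hedged sketch of the new `replace_additional_special_tokens` switch described in the docstring above (the tokenizer instance `tok` and the token strings are illustrative):

```python
# Default behaviour: the provided list replaces the current additional_special_tokens.
tok.add_special_tokens({"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]})

# Extend instead of replace: previously registered additional special tokens stay
# flagged as special and are still not split during tokenization.
tok.add_special_tokens(
    {"additional_special_tokens": ["<tool_call>"]},
    replace_additional_special_tokens=False,
)
```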
@@ -972,6 +1005,11 @@ class SpecialTokensMixin:
 
         return self._add_tokens(new_tokens, special_tokens=special_tokens)
 
+    @classmethod
+    def _add_extra_special_tokens(cls, extra_sp_token: Union[str, AddedToken]):
+        if extra_sp_token not in cls.SPECIAL_TOKENS_ATTRIBUTES:
+            cls.SPECIAL_TOKENS_ATTRIBUTES.append(extra_sp_token)
+
     def _add_tokens(
         self,
         new_tokens: Union[List[str], List[AddedToken]],
@@ -1238,7 +1276,13 @@ class SpecialTokensMixin:
         """
         """
         set_attr = {}
         set_attr = {}
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-            attr_value = getattr(self, "_" + attr)
+            try:
+                attr_value = getattr(self, "_" + attr)
+            except:
+                try:
+                    attr_value = getattr(self, attr)
+                except:
+                    continue
             if attr_value:
             if attr_value:
                 set_attr[attr] = (
                 set_attr[attr] = (
                     type(attr_value)(
                     type(attr_value)(
@@ -1262,7 +1306,13 @@ class SpecialTokensMixin:
         """
         """
         set_attr = {}
         set_attr = {}
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-            attr_value = getattr(self, "_" + attr, None)
+            try:
+                attr_value = getattr(self, "_" + attr)
+            except:
+                try:
+                    attr_value = getattr(self, attr)
+                except:
+                    continue
             if attr_value:
             if attr_value:
                 set_attr[attr] = attr_value
                 set_attr[attr] = attr_value
         return set_attr
         return set_attr
@@ -1286,16 +1336,16 @@ class SpecialTokensMixin:
         Don't convert tokens of `AddedToken` type to string so they can be used to control more finely how
         special tokens are tokenized.
         """
-        all_toks = []
-        set_attr = self.special_tokens_map_extended
-        for attr_value in set_attr.values():
-            all_toks = all_toks + (
-                list(attr_value)
-                if isinstance(attr_value, (list, tuple))
-                else [attr_value]
-            )
-        all_toks = list(OrderedDict.fromkeys(all_toks))
-        return all_toks
+        all_tokens = []
+        seen = set()
+        for value in self.special_tokens_map_extended.values():
+            if isinstance(value, (list, tuple)):
+                tokens_to_add = [token for token in value if str(token) not in seen]
+            else:
+                tokens_to_add = [value] if str(value) not in seen else []
+            seen.update(map(str, tokens_to_add))
+            all_tokens.extend(tokens_to_add)
+        return all_tokens
 
     @property
     def all_special_ids(self) -> List[int]:
@@ -1419,6 +1469,12 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
         self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
 
+        self.clean_up_tokenization_spaces = kwargs.pop(
+            "clean_up_tokenization_spaces", False
+        )
+
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
@@ -1462,10 +1518,10 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         """
         """
         Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
         Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
         """
         """
+        pass
 
 
     @max_len_sentences_pair.setter
     @max_len_sentences_pair.setter
     def max_len_sentences_pair(self, value) -> int:
     def max_len_sentences_pair(self, value) -> int:
-        # For backward compatibility, allow to try to setup 'max_len_sentences_pair'.
         if (
         if (
             value == self.model_max_length - self.num_special_tokens_to_add(pair=True)
             value == self.model_max_length - self.num_special_tokens_to_add(pair=True)
             and self.verbose
             and self.verbose
@@ -1487,10 +1543,15 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         self._processor_class = processor_class
 
     def __repr__(self) -> str:
+        added_tokens_decoder_rep = "\n\t".join(
+            [f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()]
+        )
         return (
-            f"{'PretrainedTokenizer'}(name_or_path='{self.name_or_path}', "
-            f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, "
-            f"padding_side='{self.padding_side}', truncation_side='{self.truncation_side}', special_tokens={self.special_tokens_map_extended})"
+            f"{self.__class__.__name__}(name_or_path='{self.name_or_path}',"
+            f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast},"
+            f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
+            f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}), "
+            " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}"
         )
 
     def get_vocab(self) -> Dict[str, int]:
@@ -1546,17 +1607,13 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 # Load from local directory path
                 tokenizer = BertTokenizer.from_pretrained('./my_bert/')
         """
-
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         cache_dir = kwargs.pop("cache_dir", None)
         from_hf_hub = kwargs.pop("from_hf_hub", False)
         from_aistudio = kwargs.pop("from_aistudio", False)
         subfolder = kwargs.pop("subfolder", "")
         return_tokenizer_file_dir = kwargs.pop("return_tokenizer_file_dir", False)
 
-        if subfolder is None:
-            subfolder = ""
-
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         vocab_files = {}
         init_configuration = {}
 
@@ -1567,8 +1624,13 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             "chat_template_file": CHAT_TEMPLATE_CONFIG_NAME,
             "chat_template_file": CHAT_TEMPLATE_CONFIG_NAME,
         }
         }
 
 
+        if hasattr(cls, "vocab_files_names") and len(cls.resource_files_names) == 0:
+            cls.resource_files_names = copy.deepcopy(cls.vocab_files_names)
+            logging.error(
+                "The attribute 'vocab_files_names' is deprecated. Please use 'resource_files_names' instead.",
+                DeprecationWarning,
+            )
         vocab_files_target = {**cls.resource_files_names, **additional_files_names}
         vocab_files_target = {**cls.resource_files_names, **additional_files_names}
-
         # From HF Hub or AI Studio
         # From HF Hub or AI Studio
         if from_hf_hub or from_aistudio:
         if from_hf_hub or from_aistudio:
             # Only include the necessary resource files specified by the tokenizer cls
             # Only include the necessary resource files specified by the tokenizer cls
@@ -1596,29 +1658,58 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             # Assuming from community-contributed pretrained models
             for file_id, file_name in vocab_files_target.items():
                 vocab_files[file_id] = file_name
-
         resolved_vocab_files = {}
         for file_id, file_path in vocab_files.items():
-            if file_path is None or os.path.isfile(file_path):
-                resolved_vocab_files[file_id] = file_path
-                continue
-            else:
-                logging.warnings("need to download tokenizer, but not support yet.")
-            # tokenizer download not support yet
-            # resolved_vocab_files[file_id] = resolve_file_path(
-            #     pretrained_model_name_or_path,
-            #     [file_path],
-            #     subfolder,
-            #     cache_dir=cache_dir,
-            #     from_aistudio=from_aistudio,
-            #     from_hf_hub=from_hf_hub,
-            # )
+            # adapt to PaddleX
+            resolved_vocab_files[file_id] = file_path
 
         for file_id, file_path in resolved_vocab_files.items():
             if resolved_vocab_files[file_id] is not None:
                 cache_dir = os.path.dirname(resolved_vocab_files[file_id])
                 break
+        return cls._from_pretrained(
+            resolved_vocab_files,
+            pretrained_model_name_or_path,
+            init_configuration,
+            *args,
+            cache_dir=cache_dir,
+            return_tokenizer_file_dir=return_tokenizer_file_dir,
+            from_hf_hub=from_hf_hub,
+            **kwargs,
+        )
 
+    @classmethod
+    def _from_pretrained(
+        cls,
+        resolved_vocab_files,
+        pretrained_model_name_or_path,
+        init_configuration,
+        *init_inputs,
+        cache_dir=None,
+        return_tokenizer_file_dir=False,
+        from_hf_hub=False,
+        **kwargs,
+    ):
+        if cls.__name__.endswith("Fast"):
+            from_slow = kwargs.get("from_slow", False)
+        else:
+            from_slow = kwargs.get("from_slow", True)
+        has_tokenizer_file = (
+            resolved_vocab_files.get("tokenizer_file", None) is not None
+        )
+        if (
+            from_slow or not has_tokenizer_file
+        ) and cls.slow_tokenizer_class is not None:
+            slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
+                copy.deepcopy(resolved_vocab_files),
+                pretrained_model_name_or_path,
+                copy.deepcopy(init_configuration),
+                *init_inputs,
+                cache_dir=cache_dir,
+                **(copy.deepcopy(kwargs)),
+            )
+        else:
+            slow_tokenizer = None
         tokenizer_config_file_dir_list = set()
         for k, v in resolved_vocab_files.items():
             if v is not None and os.path.isfile(v):
@@ -1628,8 +1719,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         assert (
             len(tokenizer_config_file_dir_list) > 0
         ), "All tokenizer files should be in the same directory."
-        # Prepare tokenizer initialization kwargs
-        # Did we saved some inputs and kwargs to reload ?
+
         has_tokenizer_file = (
             resolved_vocab_files.get("tokenizer_file", None) is not None
         )
@@ -1637,15 +1727,34 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         if tokenizer_config_file is not None:
         if tokenizer_config_file is not None:
             with io.open(tokenizer_config_file, encoding="utf-8") as f:
             with io.open(tokenizer_config_file, encoding="utf-8") as f:
                 init_kwargs = json.load(f)
                 init_kwargs = json.load(f)
+            init_kwargs.pop("tokenizer_class", None)
         else:
         else:
             init_kwargs = init_configuration
             init_kwargs = init_configuration
 
 
-        # position args are stored in kwargs, maybe better not include
-        init_args = init_kwargs.pop("init_args", ())
+        if slow_tokenizer is not None:
+            init_kwargs["__slow_tokenizer"] = slow_tokenizer
+        init_kwargs["name_or_path"] = pretrained_model_name_or_path
+        init_kwargs["from_slow"] = from_slow
+
+        pass_added_tokens_file = False
+        added_tokens_decoder: Dict[int, AddedToken] = {}
+        if "added_tokens_decoder" in init_kwargs:
+            for idx, token in init_kwargs["added_tokens_decoder"].items():
+                if isinstance(token, dict):
+                    token = AddedToken(**token)
+                if isinstance(token, AddedToken):
+                    added_tokens_decoder[int(idx)] = token
+                else:
+                    raise ValueError(
+                        f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
+                    )
+            init_kwargs["added_tokens_decoder"] = (
+                added_tokens_decoder  # NOTE: in tokenizer_config.json, the registered `added_tokens_decoder` is parsed as a plain dict
+            )
+            pass_added_tokens_file = True
+
         init_kwargs.pop("init_class", None)
 
-        # Update with newly provided args and kwargs
-        init_args = init_args if not args else args
         init_kwargs.update(kwargs)
 
         def convert_added_tokens(obj):
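For context, the hunk above rebuilds the `added_tokens_decoder` section of tokenizer_config.json, which is serialized as plain dictionaries, into `AddedToken` objects keyed by integer token id. A minimal, self-contained sketch of that conversion, assuming a hypothetical config snippet and a stand-in `AddedToken` dataclass rather than the class defined in tokenizer_utils_base.py:

from dataclasses import dataclass

# Stand-in for the AddedToken class from tokenizer_utils_base.py, kept minimal for illustration.
@dataclass
class AddedToken:
    content: str
    lstrip: bool = False
    rstrip: bool = False
    single_word: bool = False
    special: bool = False

# Hypothetical "added_tokens_decoder" entry as it might appear in tokenizer_config.json;
# the ids and token strings are illustrative only.
raw = {
    "151643": {"content": "<|endoftext|>", "special": True},
    "151644": {"content": "<|im_start|>", "special": True},
}

added_tokens_decoder = {}
for idx, token in raw.items():
    if isinstance(token, dict):
        token = AddedToken(**token)
    added_tokens_decoder[int(idx)] = token  # keys become integer token ids

print(added_tokens_decoder[151644].content)  # -> <|im_start|>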
@@ -1663,10 +1772,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             return obj
 
         init_kwargs = convert_added_tokens(init_kwargs)
-        # Set max length if needed
         if pretrained_model_name_or_path in cls.max_model_input_sizes:
-            # if we're using a pretrained model, ensure the tokenizer
-            # wont index sequences longer than the number of positional embeddings
             model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
             if model_max_length is not None and isinstance(
                 model_max_length, (int, float)
@@ -1675,32 +1781,28 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                     init_kwargs.get("model_max_length", int(1e30)), model_max_length
                 )
 
-        added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
-        # Merge resolved_vocab_files arguments in init_kwargs if not including.
-        # Maybe need more ways to load resources.
         for args_name, file_path in resolved_vocab_files.items():
-            # when `pretrained_model_name_or_path` is a pretrained model name,
-            # use pretrained_init_configuration as `init_kwargs` to init which
-            # does not include the vocab file in it, thus add vocab file into
-            # args.
-            if args_name not in init_kwargs:
+            if args_name not in init_kwargs or init_kwargs[args_name] is None:
                 init_kwargs[args_name] = file_path
-            # when `pretrained_model_name_or_path` is a pretrained model dir,
-            # use tokenizer_config_file.json as `init_kwargs` to init which
-            # does include a vocab file path in it. However, if the vocab file
-            # path included in json does not exist, such as was deleted, to make
-            # it still work, use the vocab file under this dir.
             elif not os.path.isfile(init_kwargs[args_name] or "") and os.path.isfile(
                 file_path
             ):
                 init_kwargs[args_name] = file_path
 
-        # TODO(zhoushunjie): It's not supportted to load tokenizer.json of hf so far.
         if from_hf_hub and "tokenizer_file" in init_kwargs:
             init_kwargs.pop("tokenizer_file")
 
-        # TODO(guosheng): avoid reduplication of position args and key word args
-        tokenizer = cls(*init_args, **init_kwargs)
+        try:
+            tokenizer = cls(*init_inputs, **init_kwargs)
+        # adapt to PaddleX
+        except RuntimeError as e:
+            if "sentencepiece_processor.cc" in str(e):
+                logging.info(
+                    "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead. "
+                    "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).",
+                )
+            return False
+
         chat_template = init_kwargs.pop("chat_template", None)
         if chat_template is not None:
             tokenizer.init_chat_template(chat_template)
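Taken together, the `_from_pretrained` changes above back the public `from_pretrained` entry point. A hedged usage sketch, where the local directory path is an assumption and not part of this diff:

from paddlex.inference.models.common.tokenizer import Qwen2Tokenizer

# Assumes a directory containing vocab.json, merges.txt and tokenizer_config.json.
tokenizer = Qwen2Tokenizer.from_pretrained("/path/to/qwen2_tokenizer_dir")
encoded = tokenizer("hello world")
print(encoded["input_ids"])
print(tokenizer.decode(encoded["input_ids"]))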
@@ -1714,11 +1816,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 special_tokens_map = json.load(special_tokens_map_handle)
             for key, value in special_tokens_map.items():
                 if key in kwargs and kwargs[key]:
-                    # This value has already been redefined by the kwargs
-                    # We keep this new value and ignore the one stored in the special_tokens_map_file
-
                     continue
-
                 if isinstance(value, dict):
                     value = AddedToken(**value)
                 elif isinstance(value, list):
@@ -1727,13 +1825,15 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                         for token in value
                     ]
                 setattr(tokenizer, key, value)
-        # Add supplementary tokens.
+                cls._add_extra_special_tokens(key)
+
         special_tokens = tokenizer.all_special_tokens
+        added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
+        added_tokens_file = None if pass_added_tokens_file else added_tokens_file
         if added_tokens_file is not None:
             with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
                 added_tok_encoder = json.load(added_tokens_handle)
 
-            # Sort added tokens by index
             added_tok_encoder_sorted = list(
                 sorted(added_tok_encoder.items(), key=lambda x: x[1])
             )
@@ -1743,14 +1843,11 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                     and index != len(tokenizer)
                     and tokenizer.convert_tokens_to_ids(token) != index
                 ):
-                    # index is the current length of the tokenizer (not in vocabulary)
                     raise ValueError(
                         f"Wrong index found for {token}: should be {tokenizer.convert_tokens_to_ids(token)} but found "
                         f"{index}."
                     )
                 elif not has_tokenizer_file and index != len(tokenizer):
-                    # Tokenizer slow: added token cannot already be in the vocabulary so its index needs to be the
-                    # current length of the tokenizer.
                     raise ValueError(
                         f"Non-consecutive added token '{token}' found. "
                         f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
@@ -1759,15 +1856,12 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 tokenizer.add_tokens(
                     token, special_tokens=bool(token in special_tokens)
                 )
-        # Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
         added_tokens = tokenizer.sanitize_special_tokens()
         if added_tokens:
             logging.info(
                 "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
             )
-        # save all of related things into default root dir
         if pretrained_model_name_or_path in cls.pretrained_init_configuration:
-            # tokenizer.save_pretrained(os.path.join(cache_dir, pretrained_model_name_or_path, subfolder))
             tokenizer.save_pretrained(cache_dir)
 
         if return_tokenizer_file_dir:
@@ -1826,7 +1920,6 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         for file_id in self.resource_files_names.keys():
             tokenizer_config.pop(file_id, None)
 
-        # Sanitize AddedTokens
         def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
             if isinstance(obj, AddedToken):
                 out = obj.__getstate__()
@@ -1844,10 +1937,16 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 }
             return obj
 
-        # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
         tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
 
-        # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
+        added_tokens = {}
+        for key, value in self.added_tokens_decoder.items():
+            if isinstance(value, AddedToken):
+                added_tokens[key] = value.__getstate__()
+            else:
+                added_tokens[key] = AddedToken(value).__getstate__()
+        tokenizer_config["added_tokens_decoder"] = added_tokens
+
         tokenizer_class = self.__class__.__name__
         tokenizer_config["tokenizer_class"] = tokenizer_class
 
@@ -1855,7 +1954,6 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             f.write(json.dumps(tokenizer_config, ensure_ascii=False))
         logging.info(f"tokenizer config file saved in {tokenizer_config_file}")
 
-        # Sanitize AddedTokens in special_tokens_map
         write_dict = convert_added_tokens(
             self.special_tokens_map_extended, add_type_field=False
         )
@@ -1945,8 +2043,6 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
         old_pad_to_max_length = kwargs.pop("pad_to_max_seq_len", False)
 
-        # Backward compatibility for previous behavior, maybe we should deprecate it:
-        # If you only set max_length, it activates truncation for max_length
         if max_length is not None and padding is False and truncation is False:
             if verbose:
                 if not self.deprecation_warnings.get(
@@ -1991,7 +2087,6 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                         warnings.warn(
                             "Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`."
                         )
-                # Default to pad to the longest sequence in the batch
                 padding_strategy = PaddingStrategy.LONGEST
             elif not isinstance(padding, PaddingStrategy):
                 padding_strategy = PaddingStrategy(padding)
@@ -2105,6 +2200,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         return_offsets_mapping: bool = False,
         add_special_tokens: bool = True,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         verbose: bool = True,
         **kwargs,
@@ -2214,6 +2310,9 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                 the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
                 Defaults to `None`.
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
             return_tensors (str or [TensorType], optional):
                 If set, will return tensors instead of list of python integers. Acceptable values are:
 
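Since the docstring above introduces the per-call `padding_side` argument, here is a short hedged example of how it might be used; the input batch is made up, and `tokenizer` stands for any tokenizer built on this base class:

batch = ["a short prompt", "a noticeably longer prompt that forces padding"]
encoded = tokenizer(
    batch,
    padding=True,             # pad to the longest sequence in the batch
    padding_side="left",      # per-call override added in this change
    return_attention_mask=True,
)
# With left padding, the attention_mask row of the shorter sequence starts with zeros
# and its input_ids row starts with pad_token_id entries.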
@@ -2332,6 +2431,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 return_offsets_mapping=return_offsets_mapping,
                 add_special_tokens=add_special_tokens,
                 pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
                 return_tensors=return_tensors,
                 verbose=verbose,
                 **kwargs,
@@ -2354,6 +2454,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 return_offsets_mapping=return_offsets_mapping,
                 add_special_tokens=add_special_tokens,
                 pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
                 return_tensors=return_tensors,
                 verbose=verbose,
                 **kwargs,
@@ -2370,6 +2471,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         stride: int = 0,
         is_split_into_words: bool = False,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_token_type_ids: Optional[bool] = None,
         return_attention_mask: Optional[bool] = None,
@@ -2426,6 +2528,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             stride=stride,
             is_split_into_words=is_split_into_words,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_tensors=return_tensors,
             return_position_ids=return_position_ids,
             return_token_type_ids=return_token_type_ids,
@@ -2448,6 +2551,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         max_length: Optional[int] = None,
         stride: int = 0,
         is_split_into_words: bool = False,
+        padding_side: Optional[Literal["right", "left"]] = None,
         pad_to_multiple_of: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -2501,6 +2605,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             stride=stride,
             is_split_into_words=is_split_into_words,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_tensors=return_tensors,
             return_token_type_ids=return_token_type_ids,
             return_attention_mask=return_attention_mask,
@@ -2523,6 +2628,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         stride: int = 0,
         is_split_into_words: bool = False,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_position_ids: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -2562,6 +2668,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         return_offsets_mapping=False,
         add_special_tokens=True,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         verbose: bool = True,
         **kwargs,
@@ -2614,6 +2721,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             stride=stride,
             is_split_into_words=is_split_into_words,
             pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
             return_tensors=return_tensors,
             return_position_ids=return_position_ids,
             return_token_type_ids=return_token_type_ids,
@@ -2644,6 +2752,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         stride: int = 0,
         is_split_into_words: bool = False,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_position_ids: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_token_type_ids: Optional[bool] = None,
@@ -2669,6 +2778,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         ],
         padding: Union[bool, str, PaddingStrategy] = True,
         max_length: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         pad_to_multiple_of: Optional[int] = None,
         return_attention_mask: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
@@ -2713,6 +2823,9 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
                 This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                 >= 7.5 (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
             return_attention_mask (`bool`, *optional*):
                 Whether to return the attention mask. If left to the default, will return the attention mask according
                 to the specific tokenizer's default, defined by the `return_outputs` attribute.
@@ -2781,13 +2894,28 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
         required_input = encoded_inputs[self.model_input_names[0]]
         if required_input and not isinstance(required_input[0], (list, tuple)):
-            encoded_inputs = self._pad(
-                encoded_inputs,
-                max_length=max_length,
-                padding_strategy=padding_strategy,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-            )
+            # some tokenizers might not have the padding_side attribute
+            if "padding_side" in set(inspect.signature(self._pad).parameters.keys()):
+                encoded_inputs = self._pad(
+                    encoded_inputs,
+                    max_length=max_length,
+                    padding_strategy=padding_strategy,
+                    pad_to_multiple_of=pad_to_multiple_of,
+                    padding_side=padding_side,
+                    return_attention_mask=return_attention_mask,
+                )
+            else:
+                original_padding_side = self.padding_side
+                self.padding_side = padding_side
+                encoded_inputs = self._pad(
+                    encoded_inputs,
+                    max_length=max_length,
+                    padding_strategy=padding_strategy,
+                    pad_to_multiple_of=pad_to_multiple_of,
+                    return_attention_mask=return_attention_mask,
+                )
+                self.padding_side = original_padding_side
+
             return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
 
         batch_size = len(required_input)
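The dispatch added above only forwards `padding_side` when the (possibly overridden) `_pad` declares that parameter, and otherwise temporarily swaps the instance attribute. A standalone sketch of the same pattern, using toy classes that are not part of this codebase:

import inspect

class Base:
    padding_side = "right"

    def _pad(self, data, padding_side="right"):
        return f"padded on the {padding_side}"

class LegacyOverride(Base):
    def _pad(self, data):  # an older override that predates the new keyword
        return f"padded on the {self.padding_side}"

for tok in (Base(), LegacyOverride()):
    if "padding_side" in inspect.signature(tok._pad).parameters:
        print(tok._pad([], padding_side="left"))
    else:
        original = tok.padding_side        # fall back to swapping the attribute
        tok.padding_side = "left"
        print(tok._pad([]))
        tok.padding_side = original        # restore the previous default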
@@ -2806,6 +2934,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 inputs,
                 max_length=max_length,
                 padding_strategy=padding_strategy,
+                padding_side=padding_side,
                 pad_to_multiple_of=pad_to_multiple_of,
                 return_attention_mask=return_attention_mask,
             )
@@ -2888,6 +3017,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         max_length: Optional[int] = None,
         stride: int = 0,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_position_ids=None,
         return_token_type_ids: Optional[bool] = None,
@@ -3038,6 +3168,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 max_length=max_length,
                 padding=padding_strategy.value,
                 pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
                 return_attention_mask=return_attention_mask,
             )
 
@@ -3190,6 +3321,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         max_length: Optional[int] = None,
         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
         pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[Literal["right", "left"]] = None,
         return_attention_mask: Optional[bool] = None,
     ) -> dict:
         """
@@ -3205,13 +3337,16 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                 - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                 - PaddingStrategy.DO_NOT_PAD: Do not pad
-                The tokenizer padding sides are defined in self.padding_side:
+                The tokenizer padding sides are defined by the `padding_side` argument:
 
                     - 'left': pads on the left of the sequences
                     - 'right': pads on the right of the sequences
             pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                 This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                 >= 7.5 (Volta).
+            padding_side: (optional) The side on which the model should have padding applied.
+                Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
             return_attention_mask:
                 (optional) Set to False to avoid returning attention mask (default: set to model specifics)
         """
@@ -3245,12 +3380,33 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
         if needs_to_be_padded:
             difference = max_length - len(required_input)
+            padding_side = (
+                padding_side if padding_side is not None else self.padding_side
+            )
 
-            if self.padding_side == "right":
+            if padding_side == "right":
                 if return_attention_mask:
-
-                    encoded_inputs["attention_mask"] = (
-                        encoded_inputs["attention_mask"] + [0] * difference
+                    if len(np.shape(encoded_inputs["attention_mask"])) > 2:
+                        encoded_inputs["attention_mask"] = np.pad(
+                            encoded_inputs["attention_mask"],
+                            pad_width=[(0, 0), (0, difference), (0, difference)],
+                            mode="constant",
+                            constant_values=0,
+                        ).tolist()
+                    else:
+                        encoded_inputs["attention_mask"] = (
+                            encoded_inputs["attention_mask"] + [0] * difference
+                        )
+                if "attn_mask_startend_row_indices" in encoded_inputs:
+                    encoded_inputs["attn_mask_startend_row_indices"] = np.concatenate(
+                        [
+                            np.array(
+                                [encoded_inputs["attn_mask_startend_row_indices"]],
+                                dtype=np.int32,
+                            ),
+                            np.zeros([1, difference], dtype=np.int32),
+                        ],
+                        axis=-1,
                     )
                     )
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = (
                 encoded_inputs[self.model_input_names[0]] = (
                 encoded_inputs[self.model_input_names[0]] = (
                     required_input + [self.pad_token_id] * difference
                 )
+            elif padding_side == "left":
                 if return_attention_mask:
                 if return_attention_mask:
-                        0
-                    ] * difference + encoded_inputs["attention_mask"]
+                    if len(np.shape(encoded_inputs["attention_mask"])) > 2:
+                        # attention_mask shape [1,seq_len,seq_len]
+                        encoded_inputs["attention_mask"] = np.pad(
+                            encoded_inputs["attention_mask"],
+                            pad_width=[(0, 0), (difference, 0), (difference, 0)],
+                            mode="constant",
+                            constant_values=0,
+                        ).tolist()
+                    else:
+                        encoded_inputs["attention_mask"] = [
+                            0
+                        ] * difference + encoded_inputs["attention_mask"]
+                if "attn_mask_startend_row_indices" in encoded_inputs:
+                    encoded_inputs["attn_mask_startend_row_indices"] = np.concatenate(
+                        [
+                            np.zeros([1, difference], dtype=np.int32),
+                            np.array(
+                                [encoded_inputs["attn_mask_startend_row_indices"]],
+                                dtype=np.int32,
+                            )
+                            + difference,
+                        ],
+                        axis=-1,
+                    )
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = [
                         self.pad_token_type_id
@@ -3323,6 +3500,15 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
                 ] * difference + required_input
             else:
                 raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+        else:
+            if "attn_mask_startend_row_indices" in encoded_inputs:
+                if len(np.shape(encoded_inputs["attn_mask_startend_row_indices"])) == 1:
+                    encoded_inputs["attn_mask_startend_row_indices"] = np.array([encoded_inputs["attn_mask_startend_row_indices"]], dtype=np.int32)  # fmt:skip
+
+        if "attn_mask_startend_row_indices" in encoded_inputs:
+            assert (
+                len(np.shape(encoded_inputs["attn_mask_startend_row_indices"])) == 2
+            )  # [num_head, seq_len]
 
         return encoded_inputs
 
@@ -3339,6 +3525,35 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
         """
         raise NotImplementedError
 
+    def decode_token(
+        self,
+        all_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+    ) -> Tuple[str, int, int]:
+        """Tokenizer decoding for the streaming generation use case. This method can be overridden for tokenizers that do not follow this API."""
+        prefix_text = self.decode(
+            all_input_ids[prefix_offset:read_offset],
+            skip_special_tokens=False,
+            clean_up_tokenization_spaces=False,
+        )
+        new_text = self.decode(
+            all_input_ids[prefix_offset:],
+            skip_special_tokens=False,
+            clean_up_tokenization_spaces=False,
+        )
+
+        if (
+            len(new_text) > len(prefix_text)
+            and not prefix_text.endswith("�")
+            and not new_text.endswith("�")
+        ):
+            prefix_index = new_text.index(prefix_text)
+            new_text = new_text[prefix_index + len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
+        else:
+            return "", prefix_offset, read_offset
+
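As a usage sketch for the new `decode_token` helper, here is a hypothetical streaming loop; the id generator and the `tokenizer` instance are assumptions, not part of this diff:

all_ids = []
prefix_offset, read_offset = 0, 0
for new_id in generated_token_ids:  # hypothetical iterable of ids produced by a model
    all_ids.append(new_id)
    text, prefix_offset, read_offset = tokenizer.decode_token(
        all_ids, prefix_offset=prefix_offset, read_offset=read_offset
    )
    if text:  # an empty string means "wait for more ids", e.g. an incomplete UTF-8 piece
        print(text, end="", flush=True)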
     def batch_decode(
         self,
         sequences,