diff --git a/setup.py b/setup.py
index 5abd08fc6..28f71f789 100644
--- a/setup.py
+++ b/setup.py
@@ -153,7 +153,6 @@ extras_require = {
     "llmcompressor": [
         "llmcompressor==0.5.1",
     ],
-    "mistral": ["mistral-common==1.5.6"],
 }
 
 install_requires, dependency_links, extras_require_build = parse_requirements(
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index c311d5247..08dcaac03 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -1,13 +1,17 @@
-"""Tokenizer loading functionality and associated utils"""
+"""Tokenizer loading functionality and associated utils."""
 
 import json
 import os
+from typing import Any, Dict, List, Optional, Union
 
+import torch
 import transformers
-from transformers import (
-    AddedToken,
-    AutoTokenizer,
-)
+from huggingface_hub import hf_hub_download
+from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.base import SpecialTokens
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from transformers import AddedToken, AutoTokenizer
 
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
@@ -23,239 +27,548 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
 PLUGIN_MANAGER = PluginManager.get_instance()
 
+# Constants
+LLAMA_TOKENIZER_CLASSES = {
+    "LlamaTokenizer",
+    "LlamaTokenizerFast",
+    "CodeLlamaTokenizer",
+    "CodeLlamaTokenizerFast",
+}
+FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
+MISTRAL_MODEL_TYPES = {"mistral", "mistral3"}
 
-def modify_tokenizer_files(
-    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
-) -> str:
+QWEN_DEFAULT_TOKEN = "<|endoftext|>"  # nosec B105
+GPTNEOX_PAD_TOKEN = "[PAD]"  # nosec B105
+CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
+
+
+class MistralTokenizerWrapper:
     """
-    Modify tokenizer files to replace added_tokens strings, save to output directory,
-    and return the path to the modified tokenizer.
-
-    This only works with reserved tokens that were added to the tokenizer, not tokens
-    already part of the vocab.
-
-    Args:
-        tokenizer_path: Path or name of the original tokenizer
-        token_mappings: Dict mapping {token_id (int): new_token_string}
-        output_dir: Directory to save the modified tokenizer
-
-    Returns:
-        Path to the modified tokenizer directory
-
-    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+    Wrapper that makes MistralTokenizer compatible with the Hugging Face tokenizer
+    interface, bridging Mistral's native tokenizer and axolotl's expectations.
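+
+    A minimal usage sketch (the file path and repo id below are illustrative
+    placeholders, not values this module requires):
+
+        tokenizer = MistralTokenizerWrapper(
+            MistralTokenizer.from_file("tekken.json"),
+            model_id="some-org/some-mistral-model",
+        )
+        ids = tokenizer.encode("Hello!")
+        text = tokenizer.decode(ids)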
""" - # Create the tokenizer directory in output_dir if it doesn't exist - tokenizer_dir = os.path.join(output_dir, "tokenizer") - os.makedirs(tokenizer_dir, exist_ok=True) - if is_local_main_process(): # pylint: disable=too-many-nested-blocks - # Load the tokenizer - temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + def __init__( + self, + mistral_tokenizer: MistralTokenizer, + model_id: str, + system_prompt: str | None = None, + ): + self.mistral_tokenizer = mistral_tokenizer + self.model_id = model_id + self.system_prompt = system_prompt + self.padding_side = "right" # Default padding side + self.chat_template = None - # Save the tokenizer to the output directory - temp_tokenizer.save_pretrained(tokenizer_dir) + # pylint: disable=unused-argument + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]: + """Encode text to token IDs""" + # For simple string encoding, create a user message + messages = [] + if self.system_prompt and add_special_tokens: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(UserMessage(content=text)) - # Get the token IDs and map them to their new values + tokenized = self.mistral_tokenizer.encode_chat_completion( + ChatCompletionRequest(messages=messages) + ) + return tokenized.tokens + + def decode( + self, + token_ids: Union[List[int], torch.Tensor], + skip_special_tokens: bool = True, # pylint: disable=unused-argument + ) -> str: + """Decode token IDs to text""" + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + return self.mistral_tokenizer.decode(token_ids) + + def __call__(self, text: str, **kwargs): + """Make the tokenizer callable like HF tokenizers""" + tokens = self.encode(text, **kwargs) + return {"input_ids": torch.tensor([tokens])} + + @property + def special_tokens_reverse_vocab(self): + # pylint: disable=protected-access + return ( + self.mistral_tokenizer.instruct_tokenizer.tokenizer._special_tokens_reverse_vocab + ) + + @property + def eos_token(self): + return SpecialTokens.eos + + @property + def bos_token(self): + return SpecialTokens.bos + + @property + def pad_token(self): + return self.eos_token # Use EOS as pad token + + @property + def unk_token(self): + return SpecialTokens.unk + + @property + def eos_token_id(self): + return self.special_tokens_reverse_vocab[self.eos_token] + + @property + def bos_token_id(self): + return self.special_tokens_reverse_vocab[self.bos_token] + + @property + def pad_token_id(self): + return self.special_tokens_reverse_vocab[self.pad_token] + + @property + def unk_token_id(self): + return self.special_tokens_reverse_vocab[self.unk_token] + + # pylint: disable=unused-argument + def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int: + """Placeholder for special token addition - Mistral tokenizer handles this internally""" + LOG.warning( + "add_special_tokens called on MistralTokenizer wrapper - this is handled internally" + ) + return 0 + + # pylint: disable=unused-argument + def add_tokens(self, tokens) -> int: + """Placeholder for token addition - Mistral tokenizer handles this internally""" + LOG.warning( + "add_tokens called on MistralTokenizer wrapper - this is handled internally" + ) + return 0 + + +class TokenizerFileModifier: + """Handles modification of tokenizer files for token overrides.""" + + def __init__( + self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str + ): + self.tokenizer_path = tokenizer_path + self.token_mappings = token_mappings + 
self.output_dir = output_dir + self.tokenizer_dir = os.path.join(output_dir, "tokenizer") + + def modify_and_save(self) -> str: + """Modify tokenizer files and return path to modified tokenizer.""" + os.makedirs(self.tokenizer_dir, exist_ok=True) + + if is_local_main_process(): + self._perform_modifications() + barrier() + + return self.tokenizer_dir + + def _perform_modifications(self): + """Perform the actual file modifications.""" + # Load and save tokenizer to output directory + temp_tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_path, use_fast=True + ) + temp_tokenizer.save_pretrained(self.tokenizer_dir) + + # Convert token mappings to proper format token_id_mappings = { - int(token_id): new_value for token_id, new_value in token_mappings.items() + int(token_id): new_value + for token_id, new_value in self.token_mappings.items() } - # 1. Update tokenizer_config.json - added_tokens_decoder - config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config_data = json.load(f) + # Update both tokenizer files + self._update_tokenizer_config(token_id_mappings) + self._update_tokenizer_json(token_id_mappings) - # Update added_tokens_decoder - if "added_tokens_decoder" in config_data: - for token_id, new_value in token_id_mappings.items(): - token_id_str = str(token_id) - if token_id_str in config_data["added_tokens_decoder"]: - config_data["added_tokens_decoder"][token_id_str][ - "content" - ] = new_value - else: - raise ValueError( - f"Token ID {token_id_str} not found in added_tokens_decoder" - ) + def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]): + """Update tokenizer_config.json with new token mappings.""" + config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json") + if not os.path.exists(config_path): + return - # Write the updated config back - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f, indent=2) + with open(config_path, "r", encoding="utf-8") as f: + config_data = json.load(f) - # 2. 
Update tokenizer.json - added_tokens - tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") - if os.path.exists(tokenizer_path): - with open(tokenizer_path, "r", encoding="utf-8") as f: - tokenizer_data = json.load(f) + if "added_tokens_decoder" in config_data: + self._update_added_tokens_decoder(config_data, token_id_mappings) - # Update added_tokens - if "added_tokens" in tokenizer_data: - for token_id, new_value in token_id_mappings.items(): - for i, token_entry in enumerate(tokenizer_data["added_tokens"]): - if token_entry["id"] == token_id: - tokenizer_data["added_tokens"][i]["content"] = new_value - break - else: - # Reaching this section means the token_id was not found in tokenizer.json added_tokens - raise ValueError( - f"Token ID {token_id} not found in added_tokens" - ) - if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: - for token_id, new_value in token_id_mappings.items(): - for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): - if entry_id == token_id: - del tokenizer_data["model"]["vocab"][entry_val] - tokenizer_data["model"]["vocab"][new_value] = token_id - break + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f, indent=2) - # Write the updated tokenizer data back - with open(tokenizer_path, "w", encoding="utf-8") as f: - json.dump(tokenizer_data, f, indent=2) + def _update_added_tokens_decoder( + self, config_data: Dict, token_id_mappings: Dict[int, str] + ): + """Update the added_tokens_decoder section.""" + for token_id, new_value in token_id_mappings.items(): + token_id_str = str(token_id) + if token_id_str in config_data["added_tokens_decoder"]: + config_data["added_tokens_decoder"][token_id_str]["content"] = new_value + else: + raise ValueError( + f"Token ID {token_id_str} not found in added_tokens_decoder" + ) - barrier() - return tokenizer_dir + def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]): + """Update tokenizer.json with new token mappings.""" + tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json") + if not os.path.exists(tokenizer_json_path): + return + + with open(tokenizer_json_path, "r", encoding="utf-8") as f: + tokenizer_data = json.load(f) + + self._update_added_tokens_list(tokenizer_data, token_id_mappings) + self._update_vocab_mappings(tokenizer_data, token_id_mappings) + + with open(tokenizer_json_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2) + + def _update_added_tokens_list( + self, tokenizer_data: Dict, token_id_mappings: Dict[int, str] + ): + """Update the added_tokens list in tokenizer.json.""" + if "added_tokens" not in tokenizer_data: + return + + for token_id, new_value in token_id_mappings.items(): + for i, token_entry in enumerate(tokenizer_data["added_tokens"]): + if token_entry["id"] == token_id: + tokenizer_data["added_tokens"][i]["content"] = new_value + break + else: + raise ValueError(f"Token ID {token_id} not found in added_tokens") + + def _update_vocab_mappings( + self, tokenizer_data: Dict, token_id_mappings: Dict[int, str] + ): + """Update vocab mappings in tokenizer.json.""" + if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")): + return + + vocab = tokenizer_data["model"]["vocab"] + for token_id, new_value in token_id_mappings.items(): + # Find and update the vocab entry + for entry_val, entry_id in list(vocab.items()): + if entry_id == token_id: + del vocab[entry_val] + vocab[new_value] = token_id + break + + +class TokenizerConfiguration: + """Handles 
tokenizer configuration and initialization."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.model_config = load_model_config(cfg)
+
+    def should_use_mistral_tokenizer(self) -> bool:
+        """Determine if Mistral tokenizer should be used."""
+        # Use the native Mistral tokenizer whenever the model type is a Mistral variant
+        return self.model_config.model_type in MISTRAL_MODEL_TYPES
+
+    def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
+        """Load Mistral tokenizer from model configuration."""
+        model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
+            self.cfg, "base_model", None
+        )
+        if not model_id:
+            raise ValueError(
+                "model_name_or_path or base_model must be specified for Mistral tokenizer"
+            )
+
+        try:
+            # Download the tekken.json file for the tokenizer
+            tekken_file = hf_hub_download(repo_id=model_id, filename="tekken.json")
+
+            # Load the Mistral tokenizer
+            mistral_tokenizer = MistralTokenizer.from_file(tekken_file)
+
+            # Wrap it for compatibility
+            wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
+
+            LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
+            return wrapped_tokenizer
+
+        except Exception as e:
+            LOG.error(f"Failed to load Mistral tokenizer: {e}")
+            raise
+
+    def get_tokenizer_class(self):
+        """Get the appropriate tokenizer class."""
+        if self.cfg.tokenizer_type:
+            return getattr(transformers, self.cfg.tokenizer_type)
+        return AutoTokenizer
+
+    def get_tokenizer_kwargs(self) -> Dict[str, Any]:
+        """Build tokenizer initialization kwargs."""
+        kwargs = {}
+        if self.cfg.tokenizer_legacy is not None:
+            # True is the default w/ https://github.com/huggingface/transformers/pull/25224
+            kwargs["legacy"] = self.cfg.tokenizer_legacy
+        return kwargs
+
+    def get_tokenizer_path(self) -> str:
+        """Get the tokenizer path, applying overrides if needed."""
+        tokenizer_path = self.cfg.tokenizer_config
+
+        if self.cfg.added_tokens_overrides:
+            modifier = TokenizerFileModifier(
+                tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
+            )
+            tokenizer_path = modifier.modify_and_save()
+
+        return tokenizer_path
+
+    def should_use_fast_tokenizer(self) -> bool:
+        """Determine if fast tokenizer should be used."""
+        return (
+            self.cfg.tokenizer_use_fast
+            if self.cfg.tokenizer_use_fast is not None
+            else True
+        )
+
+
+class TokenizerPostProcessor:
+    """Handles post-processing configuration of loaded tokenizers."""
+
+    def __init__(self, tokenizer, cfg):
+        self.tokenizer = tokenizer
+        self.cfg = cfg
+        self.model_config = load_model_config(cfg)
+
+    def apply_all_configurations(self):
+        """Apply all post-processing configurations to the tokenizer."""
+        # Skip most configurations for Mistral wrapper
+        if isinstance(self.tokenizer, MistralTokenizerWrapper):
+            self._configure_mistral_wrapper()
+            return
+
+        self._configure_padding_token()
+        self._configure_gptneox_settings()
+        self._configure_mistral_padding()
+        self._configure_qwen_tokens()
+        self._add_special_tokens()
+        self._add_regular_tokens()
+        self._configure_chat_template()
+
+    def _configure_mistral_wrapper(self):
+        """Apply the limited set of configurations the Mistral wrapper supports."""
+        # Reuse the shared left-padding rule rather than duplicating it here
+        self._configure_mistral_padding()
+
+        # Configure chat template for Mistral
+        self._configure_chat_template()
+
+    def _configure_padding_token(self):
+        """Configure padding token for Llama-based tokenizers."""
+        if (
+            self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
+            and hasattr(self.tokenizer, "pad_token")
+            and not self.tokenizer.pad_token
+        ):
+            # Set a pad_token, but use eos_token so we don't add a new token
+            self.tokenizer.pad_token = 
LLAMA_DEFAULT_EOS_TOKEN + + def _configure_gptneox_settings(self): + """Configure GPTNeoX-specific settings.""" + if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": + self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN}) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + def _configure_mistral_padding(self): + """Configure left padding for Mistral models with Flash Attention.""" + if ( + self.cfg.is_mistral_derived_model + and self.cfg.flash_attention + and not self.cfg.sample_packing + ): + self.tokenizer.padding_side = "left" + + def _configure_qwen_tokens(self): + """Configure special tokens for Qwen models.""" + if not self.cfg.is_qwen_derived_model: + return + + # Set token IDs + token_id_attributes = [ + "bos_token_id", + "eos_token_id", + "pad_token_id", + "unk_token_id", + ] + for attr_name in token_id_attributes: + if getattr(self.tokenizer, attr_name) is None: + setattr(self.tokenizer, attr_name, self.tokenizer.eod_id) + + # Set token strings + token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"] + for attr_name in token_name_attributes: + if getattr(self.tokenizer, attr_name) is None: + setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN) + + def _add_special_tokens(self): + """Add special tokens from configuration.""" + if not self.cfg.special_tokens: + return + + special_tokens_dict = self.cfg.special_tokens.to_dict() + additional_special_tokens = special_tokens_dict.pop( + "additional_special_tokens", None + ) + + self._validate_and_add_special_tokens(special_tokens_dict) + self._update_post_processor_if_needed(special_tokens_dict) + self._add_additional_special_tokens_if_present(additional_special_tokens) + + def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]): + """Validate special tokens for adapter training and add them.""" + lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type) + + for key, value in special_tokens.items(): + self._validate_token_for_adapter(key, value, lora_modules_to_save) + self.tokenizer.add_special_tokens( + {key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)} + ) + + def _validate_token_for_adapter( + self, key: str, value: str, lora_modules_to_save: List[str] + ): + """Validate a single token for adapter training requirements.""" + if not self._should_validate_token_for_adapter( + key, value, lora_modules_to_save + ): + return + + modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save) + raise ValueError( + f"Please set lora_modules_to_save to [{modules_str}] " + f"when using an adapter and changing the special tokens." 
+ ) + + def _should_validate_token_for_adapter( + self, key: str, value: str, lora_modules_to_save: List[str] + ) -> bool: + """Check if token should be validated for adapter configuration.""" + if key == "pad_token" or not self.cfg.adapter: + return False + + current_token = getattr(self.tokenizer, key) + token_changed = current_token is None or current_token != value + token_is_multi_char = ( + len(self.tokenizer.encode(value, add_special_tokens=False)) > 2 + ) + lora_modules_missing = not self.cfg.lora_modules_to_save or not all( + x in self.cfg.lora_modules_to_save for x in lora_modules_to_save + ) + + return token_changed and token_is_multi_char and lora_modules_missing + + def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]): + """Update post processor for Llama tokenizers when BOS/EOS tokens are added.""" + has_bos_and_eos = ( + "bos_token" in special_tokens and "eos_token" in special_tokens + ) + is_fast_llama = ( + self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES + ) + + if is_fast_llama and has_bos_and_eos: + self.tokenizer.update_post_processor() + + def _add_additional_special_tokens_if_present( + self, additional_special_tokens: Optional[List[str]] + ): + """Add additional special tokens if they exist.""" + if additional_special_tokens is not None: + self.tokenizer.add_special_tokens( + {"additional_special_tokens": additional_special_tokens} + ) + + def _add_regular_tokens(self): + """Add regular (non-special) tokens from configuration.""" + if self.cfg.tokens: + self.tokenizer.add_tokens( + [ + AddedToken(token, rstrip=False, lstrip=False, normalized=False) + for token in self.cfg.tokens + ] + ) + + def _configure_chat_template(self): + """Configure chat template if specified.""" + if not self.cfg.chat_template: + LOG.info( + "No Chat template selected. Consider adding a chat template for easier inference." + ) + return + + chat_template_string = get_chat_template_from_config( + cfg=self.cfg, + tokenizer=self.tokenizer, + ) + + if self._should_replace_default_system_message(): + chat_template_string = chat_template_string.replace( + CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message + ) + + self.tokenizer.chat_template = chat_template_string + + def _should_replace_default_system_message(self) -> bool: + """Check if default system message should be replaced.""" + return self.cfg.default_system_message and self.cfg.chat_template == "chatml" def load_tokenizer(cfg): - """Load and configure the tokenizer based on the provided config.""" - model_config = load_model_config(cfg) - tokenizer_kwargs = {} - use_fast = True # this is the default + """Load and configure the tokenizer based on the provided config. - if cfg.tokenizer_use_fast is not None: - use_fast = cfg.tokenizer_use_fast - if cfg.tokenizer_legacy is not None: - # True is the default w/ https://github.com/huggingface/transformers/pull/25224 - tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy + This function handles the complete tokenizer loading pipeline: + - Check if Mistral tokenizer should be used + - Configure tokenizer parameters and get the appropriate class + - Handle token file modifications if needed + - Initialize the tokenizer with the correct parameters + - Apply all post-processing configurations (padding, special tokens, etc.) + - Set up chat templates and logging - tokenizer_cls = AutoTokenizer - if cfg.tokenizer_type: - tokenizer_cls = getattr(transformers, cfg.tokenizer_type) + Args: + cfg: Dictionary mapping `axolotl` config keys to values. 
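+
+        A minimal call sketch (field values are illustrative placeholders):
+
+            from axolotl.utils.dict import DictDefault
+
+            cfg = DictDefault(
+                {"base_model": "org/model", "tokenizer_config": "org/model"}
+            )
+            tokenizer = load_tokenizer(cfg)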
- # Set base tokenizer path - tokenizer_path = cfg.tokenizer_config + Returns: + Fully configured tokenizer instance. + """ + # Configure tokenizer parameters + config = TokenizerConfiguration(cfg) - # Apply token string overrides if specified - if cfg.added_tokens_overrides: - # Modify tokenizer files and get path to modified tokenizer - tokenizer_path = modify_tokenizer_files( - tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir + # Check if we should use Mistral tokenizer + if config.should_use_mistral_tokenizer(): + tokenizer = config.load_mistral_tokenizer() + else: + # Standard tokenizer loading + tokenizer_cls = config.get_tokenizer_class() + tokenizer_path = config.get_tokenizer_path() + use_fast = config.should_use_fast_tokenizer() + tokenizer_kwargs = config.get_tokenizer_kwargs() + + # Initialize the tokenizer + tokenizer = tokenizer_cls.from_pretrained( + tokenizer_path, + trust_remote_code=cfg.trust_remote_code or False, + use_fast=use_fast, + **tokenizer_kwargs, ) - tokenizer = tokenizer_cls.from_pretrained( - tokenizer_path, - trust_remote_code=cfg.trust_remote_code or False, - use_fast=use_fast, - **tokenizer_kwargs, - ) - - if ( - tokenizer.__class__.__name__ - in [ - "LlamaTokenizer", - "LlamaTokenizerFast", - "CodeLlamaTokenizer", - "CodeLlamaTokenizerFast", - ] - and hasattr(tokenizer, "pad_token") - and not tokenizer.pad_token - ): - # set a pad_token, but use eos_token so we don't add a new token - tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: - tokenizer.padding_side = "left" - - # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: - token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] - for attr_name in token_ids: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, tokenizer.eod_id) - - token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] - for attr_name in token_names: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, "<|endoftext|>") - - additional_special_tokens = None - if cfg.special_tokens: - special_tokens = cfg.special_tokens.to_dict() - additional_special_tokens = special_tokens.pop( - "additional_special_tokens", None - ) - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - for k, val in special_tokens.items(): - # check if new special token is not already in tokenizer and - # is adapter training to make sure lora_modules_to_save is set - # pylint: disable=too-many-boolean-expressions - if ( - (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) - and cfg.adapter - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save for x in lora_modules_to_save - ) - ) - and k != "pad_token" - ): - lora_modules_to_save = ", ".join( - [f"`{x}`" for x in lora_modules_to_save] - ) - raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." 
- ) - - tokenizer.add_special_tokens( - {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} - ) - - # If we add bos_token and eos_token, we need to update the post processor to - # handle them correctly. - # https://github.com/huggingface/transformers/pull/24132 - bos_or_eos_in_special_tokens = ( - "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens - ) - if ( - tokenizer.__class__.__name__ - in ( - "LlamaTokenizerFast", - "CodeLlamaTokenizerFast", - ) - and bos_or_eos_in_special_tokens - ): - tokenizer.update_post_processor() - - if cfg.tokens: - tokenizer.add_tokens( - [ - AddedToken(token, rstrip=False, lstrip=False, normalized=False) - for token in cfg.tokens - ] - ) - - # Additional special tokens are a List, and need to be treated differently than regular special - # tokens. We add them after we have called `add_tokens` in case these additional special tokens - # are new tokens. - # - # Usage: - # - # ```py - # special_tokens: - # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] - # ``` - if additional_special_tokens is not None: - tokenizer.add_special_tokens( - {"additional_special_tokens": additional_special_tokens} - ) + # Apply all post-processing configurations + post_processor = TokenizerPostProcessor(tokenizer, cfg) + post_processor.apply_all_configurations() if is_main_process(use_environ=True): LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") @@ -263,19 +576,4 @@ def load_tokenizer(cfg): LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - if cfg.chat_template: - chat_template_string = get_chat_template_from_config( - cfg=cfg, - tokenizer=tokenizer, - ) - if cfg.default_system_message and cfg.chat_template == "chatml": - chat_template_string = chat_template_string.replace( - "You are a helpful assistant.", cfg.default_system_message - ) - - tokenizer.chat_template = chat_template_string - else: - LOG.info( - "No Chat template selected. Consider adding a chat template for easier inference." - ) return tokenizer
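
For reviewers, a quick end-to-end smoke test of the two new code paths. The
repo ids, output dir, and token id below are illustrative placeholders, not
values from this change; the Mistral branch assumes a checkpoint that ships a
tekken.json file (and note that mistral-common is now imported unconditionally,
since the optional "mistral" extra was dropped from setup.py):

    from axolotl.loaders.tokenizer import load_tokenizer
    from axolotl.utils.dict import DictDefault

    # Mistral path: model_type "mistral"/"mistral3" returns a MistralTokenizerWrapper
    cfg = DictDefault(
        {
            "base_model": "some-org/some-mistral-model",  # illustrative repo id
            "tokenizer_config": "some-org/some-mistral-model",
            "output_dir": "./outputs",
        }
    )
    tokenizer = load_tokenizer(cfg)
    print(tokenizer.eos_token_id, tokenizer.decode(tokenizer.encode("hello")))

    # Override path: added_tokens_overrides routes through TokenizerFileModifier,
    # which only rewrites reserved *added* tokens, never base-vocab entries
    cfg = DictDefault(
        {
            "base_model": "some-org/some-llama-model",  # illustrative repo id
            "tokenizer_config": "some-org/some-llama-model",
            "output_dir": "./outputs",
            "added_tokens_overrides": {32002: "<|my_token|>"},  # hypothetical id
        }
    )
    tokenizer = load_tokenizer(cfg)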