diff --git a/docs/config.qmd b/docs/config.qmd
index cfd137ff0..fb0c4b59b 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -154,8 +154,6 @@ datasets:
       content: value
     # ...

-    message_property_mappings:
-
     # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
     roles:
       user: ["human", "user"]
@@ -556,6 +554,13 @@ special_tokens:
 # Add extra tokens.
 tokens:

+# Mapping of token_id to new_token_string, used to override reserved added_tokens
+# in the tokenizer. Only works for tokens that are not part of the base vocab,
+# i.e. tokens listed under added_tokens in tokenizer.json.
+added_tokens_overrides: # Dict[int, str]
+#   128041: "<|im_start|>"
+#   128042: "<|im_end|>"
+
 # FSDP
 fsdp:
   fsdp_config:
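As a usage illustration (not part of the patch): a minimal config sketch wiring the new option together with a base model. The model name is an assumption for the example, and the reserved IDs are only valid if they appear under `added_tokens` in that tokenizer's `tokenizer.json`.

```yaml
# Hypothetical axolotl config fragment: rename two reserved added_tokens
# to ChatML markers. Check tokenizer.json added_tokens for the real IDs
# of your base tokenizer before copying these values.
base_model: NousResearch/Llama-3.2-1B  # example model, assumption
added_tokens_overrides:
  128041: "<|im_start|>"
  128042: "<|im_end|>"
```

Because the override only renames existing added_token slots, the vocab size is unchanged and no embedding resize is required.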
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 180e02823..ce2586afb 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -855,6 +855,7 @@ class AxolotlInputConfig(

     special_tokens: Optional[SpecialTokensConfig] = None
     tokens: Optional[List[str]] = None
+    added_tokens_overrides: Optional[Dict[int, str]] = None

     torch_compile: Optional[Union[Literal["auto"], bool]] = None
     torch_compile_backend: Optional[str] = None
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 81a928b6e..7d6cd597a 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -79,7 +79,7 @@ def is_main_process():


 def is_local_main_process():
-    return PartialState().is_main_process
+    return PartialState().is_local_main_process


 def get_world_size():
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c4c07dd33..add690d9d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -57,7 +57,13 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
+from axolotl.utils.distributed import (
+    barrier,
+    get_device_count,
+    get_device_type,
+    is_local_main_process,
+    zero_only,
+)
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -165,7 +171,95 @@ def load_model_config(cfg):
     return model_config


+def modify_tokenizer_files(
+    tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
+) -> str:
+    """
+    Modify tokenizer files to replace added_tokens strings, save the result to
+    the output directory, and return the path to the modified tokenizer.
+
+    This only works for reserved tokens that were added to the tokenizer
+    (added_tokens), not tokens already part of the base vocab.
+
+    Args:
+        tokenizer_path: Path or name of the original tokenizer
+        token_mappings: Dict mapping {token_id (int): new_token_string}
+        output_dir: Directory to save the modified tokenizer
+
+    Returns:
+        Path to the modified tokenizer directory
+
+    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+    """
+
+    import json
+
+    # Create the tokenizer directory in output_dir if it doesn't exist
+    tokenizer_dir = os.path.join(output_dir, "tokenizer")
+    os.makedirs(tokenizer_dir, exist_ok=True)
+
+    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
+        # Load the tokenizer
+        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+
+        # Save the tokenizer to the output directory
+        temp_tokenizer.save_pretrained(tokenizer_dir)
+
+        # Normalize the mapping keys to int token IDs
+        token_id_mappings = {
+            int(token_id): new_value for token_id, new_value in token_mappings.items()
+        }
+
+        # 1. Update tokenizer_config.json - added_tokens_decoder
+        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_data = json.load(f)
+
+            # Update added_tokens_decoder
+            if "added_tokens_decoder" in config_data:
+                for token_id, new_value in token_id_mappings.items():
+                    token_id_str = str(token_id)
+                    if token_id_str in config_data["added_tokens_decoder"]:
+                        config_data["added_tokens_decoder"][token_id_str][
+                            "content"
+                        ] = new_value
+                    else:
+                        raise ValueError(
+                            f"Token ID {token_id_str} not found in added_tokens_decoder"
+                        )
+
+            # Write the updated config back
+            with open(config_path, "w", encoding="utf-8") as f:
+                json.dump(config_data, f, indent=2)
+
+        # 2. Update tokenizer.json - added_tokens
+        tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        if os.path.exists(tokenizer_json_path):
+            with open(tokenizer_json_path, "r", encoding="utf-8") as f:
+                tokenizer_data = json.load(f)
+
+            # Update added_tokens
+            if "added_tokens" in tokenizer_data:
+                for token_id, new_value in token_id_mappings.items():
+                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
+                        if token_entry["id"] == token_id:
+                            tokenizer_data["added_tokens"][i]["content"] = new_value
+                            break
+                    else:
+                        # for/else: the token_id was not found in tokenizer.json added_tokens
+                        raise ValueError(
+                            f"Token ID {token_id} not found in added_tokens"
+                        )
+
+            # Write the updated tokenizer data back
+            with open(tokenizer_json_path, "w", encoding="utf-8") as f:
+                json.dump(tokenizer_data, f, indent=2)
+
+    barrier()
+    return tokenizer_dir
+
+
 def load_tokenizer(cfg):
+    """Load and configure the tokenizer based on the provided config."""
     model_config = load_model_config(cfg)
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

+    # Set the base tokenizer path
+    tokenizer_path = cfg.tokenizer_config
+
+    # Apply added_token string overrides if specified
+    if cfg.added_tokens_overrides:
+        # Modify the tokenizer files and get the path to the modified copy
+        tokenizer_path = modify_tokenizer_files(
+            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+        )
+
     tokenizer = tokenizer_cls.from_pretrained(
-        cfg.tokenizer_config,
+        tokenizer_path,
         trust_remote_code=cfg.trust_remote_code or False,
         use_fast=use_fast,
         **tokenizer_kwargs,
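A quick way to see which token IDs are eligible for `added_tokens_overrides` (a sketch using only the public `transformers` API; the model name is an example):

```python
# List a tokenizer's added_tokens: the only IDs modify_tokenizer_files()
# can rewrite. Base-vocab tokens never appear in this mapping.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")

# get_added_vocab() returns {token_string: token_id} for added tokens only
for token, token_id in sorted(
    tokenizer.get_added_vocab().items(), key=lambda kv: kv[1]
):
    print(token_id, token)
```

One reviewer-style caveat: the `ValueError` paths above run only on the local main process, so in a multi-rank run a bad token ID surfaces on rank 0 while the other ranks wait at `barrier()`.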
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index 69c441f8c..3d568ab19 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -1,6 +1,7 @@
 """
 Test cases for the tokenizer loading
 """
+
 import unittest

 import pytest
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer


-class TestTokenizers(unittest.TestCase):
+class TestTokenizers:
     """
     test class for the load_tokenizer fn
     """
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
             }
         )
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
-        self.assertEqual(len(tokenizer), 32001)
+        assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
+        assert len(tokenizer) == 32001

         # ensure reloading the tokenizer again from cfg results in same vocab length
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(len(tokenizer), 32001)
+        assert len(tokenizer) == 32001
+
+    def test_added_tokens_overrides(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use a tokenizer that has reserved_tokens in its added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {
+                    128041: "RANDOM_OVERRIDE_1",
+                    128042: "RANDOM_OVERRIDE_2",
+                },
+                "output_dir": temp_dir,
+            }
+        )
+
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
+            128041
+        ]
+        assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
+            128042
+        ]
+
+    def test_added_tokens_overrides_with_too_large_id(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use a tokenizer that has reserved_tokens in its added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
+                "output_dir": temp_dir,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
+        ):
+            load_tokenizer(cfg)


 if __name__ == "__main__":
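The new tests take a `temp_dir` argument, which relies on a pytest fixture (this is why the class drops `unittest.TestCase`: fixtures are not injected into `TestCase` methods). If the suite does not already define it, a minimal sketch of such a fixture in `conftest.py` could look like the following; only the fixture name is fixed by the tests above, the implementation is an assumption:

```python
# conftest.py -- hypothetical minimal temp_dir fixture for the tests above
import shutil
import tempfile

import pytest


@pytest.fixture
def temp_dir():
    # Give each test an isolated output_dir, then clean it up afterwards
    path = tempfile.mkdtemp()
    yield path
    shutil.rmtree(path)
```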