From 949471039fd56dc33a5485605d15610848f3cc48 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 5 Apr 2025 01:25:44 -0400 Subject: [PATCH] fix tokenizer overrides w gemma3 (#2488) * fix tokenizer overrides w gemma3 * fix offline wrapping --- src/axolotl/utils/models.py | 7 +++++++ tests/conftest.py | 15 ++++++++++++--- tests/hf_offline_utils.py | 2 +- tests/test_tokenizers.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9df63231b..301607865 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -283,6 +283,13 @@ def modify_tokenizer_files( raise ValueError( f"Token ID {token_id} not found in added_tokens" ) + if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: + for token_id, new_value in token_id_mappings.items(): + for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): + if entry_id == token_id: + del tokenizer_data["model"]["vocab"][entry_val] + tokenizer_data["model"]["vocab"][new_value] = token_id + break # Write the updated tokenizer data back with open(tokenizer_path, "w", encoding="utf-8") as f: diff --git a/tests/conftest.py b/tests/conftest.py index c71ea1e8c..97c48db41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,6 @@ from tokenizers import AddedToken from transformers import AutoTokenizer from tests.hf_offline_utils import ( - disable_hf_offline, enable_hf_offline, hf_offline_context, ) @@ -50,7 +49,6 @@ def retry_on_request_exceptions(max_retries=3, delay=1): @retry_on_request_exceptions(max_retries=3, delay=5) -@disable_hf_offline def snapshot_download_w_retry(*args, **kwargs): """ download a model or dataset from HF Hub, retrying in requests failures. We also try to fetch it from the local @@ -62,7 +60,8 @@ def snapshot_download_w_retry(*args, **kwargs): return snapshot_download(*args, **kwargs) except LocalEntryNotFoundError: pass - return snapshot_download(*args, **kwargs) + with hf_offline_context(False): + return snapshot_download(*args, **kwargs) @pytest.fixture(scope="session", autouse=True) @@ -265,6 +264,16 @@ def download_mistral_7b_model_fixture(): ) +@pytest.fixture(scope="session", autouse=True) +def download_gemma3_4b_model_fixture(): + # download the tokenizer only + snapshot_download_w_retry( + "mlx-community/gemma-3-4b-it-8bit", + repo_type="model", + allow_patterns=["*token*", "config.json"], + ) + + @pytest.fixture(scope="session", autouse=True) def download_gemma_2b_model_fixture(): # download the tokenizer only diff --git a/tests/hf_offline_utils.py b/tests/hf_offline_utils.py index 0c7b5d4a4..385e61f18 100644 --- a/tests/hf_offline_utils.py +++ b/tests/hf_offline_utils.py @@ -95,7 +95,7 @@ def hf_offline_context(hf_hub_offline): """ original_hf_offline = os.getenv("HF_HUB_OFFLINE") os.environ["HF_HUB_OFFLINE"] = str(hf_hub_offline) - reload_modules(True) + reload_modules(bool(hf_hub_offline)) yield # Restore the original value of HF_HUB_OFFLINE environment variable if original_hf_offline is not None: diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index ef0cb14d1..ffd51bc29 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -110,6 +110,34 @@ class TestTokenizers: assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [ 128042 ] + assert ( + tokenizer.decode([128041, 128042]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2" + ) + + @enable_hf_offline + def test_added_tokens_overrides_gemma3(self, temp_dir): + cfg = DictDefault( + { + # use with tokenizer that has reserved_tokens in added_tokens + "tokenizer_config": "mlx-community/gemma-3-4b-it-8bit", + "added_tokens_overrides": { + 256001: "RANDOM_OVERRIDE_1", + 256002: "RANDOM_OVERRIDE_2", + }, + "output_dir": temp_dir, + } + ) + + tokenizer = load_tokenizer(cfg) + assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [ + 256001 + ] + assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [ + 256002 + ] + assert ( + tokenizer.decode([256001, 256002]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2" + ) @enable_hf_offline def test_added_tokens_overrides_with_toolargeid(self, temp_dir):