fix tokenizer overrides w gemma3 (#2488)
* fix tokenizer overrides w gemma3

* fix offline wrapping
@@ -283,6 +283,13 @@ def modify_tokenizer_files(
             raise ValueError(
                 f"Token ID {token_id} not found in added_tokens"
             )
+    if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
+        for token_id, new_value in token_id_mappings.items():
+            for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
+                if entry_id == token_id:
+                    del tokenizer_data["model"]["vocab"][entry_val]
+                    tokenizer_data["model"]["vocab"][new_value] = token_id
+                    break
 
     # Write the updated tokenizer data back
     with open(tokenizer_path, "w", encoding="utf-8") as f:
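The new inner loop deletes a key from model.vocab while iterating over it; the immediate break exits before the iterator advances again, so CPython never raises "dictionary changed size during iteration", but the pattern is fragile. A two-pass sketch of the same remapping that avoids mutating the dict mid-iteration (the helper name remap_vocab_tokens is illustrative, not part of the commit):

    # Hypothetical two-pass variant of the vocab remapping added above.
    def remap_vocab_tokens(tokenizer_data: dict, token_id_mappings: dict) -> None:
        vocab = tokenizer_data.get("model", {}).get("vocab", {})
        # Pass 1: invert once so each id lookup is O(1) instead of a full scan.
        id_to_token = {entry_id: entry_val for entry_val, entry_id in vocab.items()}
        # Pass 2: apply the overrides without iterating the dict being mutated.
        for token_id, new_value in token_id_mappings.items():
            old_value = id_to_token.get(token_id)
            if old_value is not None:
                del vocab[old_value]
                vocab[new_value] = token_id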
@@ -19,7 +19,6 @@ from tokenizers import AddedToken
 from transformers import AutoTokenizer
 
 from tests.hf_offline_utils import (
-    disable_hf_offline,
     enable_hf_offline,
     hf_offline_context,
 )
@@ -50,7 +49,6 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
 
 
 @retry_on_request_exceptions(max_retries=3, delay=5)
-@disable_hf_offline
 def snapshot_download_w_retry(*args, **kwargs):
     """
     download a model or dataset from HF Hub, retrying on request failures. We also try to fetch it from the local
@@ -62,7 +60,8 @@ def snapshot_download_w_retry(*args, **kwargs):
         return snapshot_download(*args, **kwargs)
     except LocalEntryNotFoundError:
         pass
-    return snapshot_download(*args, **kwargs)
+    with hf_offline_context(False):
+        return snapshot_download(*args, **kwargs)
 
 
 @pytest.fixture(scope="session", autouse=True)
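Together with the removed @disable_hf_offline decorator above, this is the "fix offline wrapping" half of the commit: instead of disabling offline mode for the entire call, including the cache-first attempt, only the network fallback after a LocalEntryNotFoundError now runs with offline mode forced off. A hedged sketch of the same pattern; the repo id and patterns are copied from the fixture below, everything else is illustrative:

    # Only the retry escapes offline mode; a cache hit never touches the Hub.
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import LocalEntryNotFoundError

    from tests.hf_offline_utils import hf_offline_context

    try:
        path = snapshot_download(
            "mlx-community/gemma-3-4b-it-8bit",
            allow_patterns=["*token*", "config.json"],
        )
    except LocalEntryNotFoundError:
        # Cache miss: temporarily force online mode for the network fetch.
        with hf_offline_context(False):
            path = snapshot_download(
                "mlx-community/gemma-3-4b-it-8bit",
                allow_patterns=["*token*", "config.json"],
            )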
@@ -265,6 +264,16 @@ def download_mistral_7b_model_fixture():
     )
 
 
+@pytest.fixture(scope="session", autouse=True)
+def download_gemma3_4b_model_fixture():
+    # download the tokenizer only
+    snapshot_download_w_retry(
+        "mlx-community/gemma-3-4b-it-8bit",
+        repo_type="model",
+        allow_patterns=["*token*", "config.json"],
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_gemma_2b_model_fixture():
     # download the tokenizer only
@@ -95,7 +95,7 @@ def hf_offline_context(hf_hub_offline):
     """
     original_hf_offline = os.getenv("HF_HUB_OFFLINE")
     os.environ["HF_HUB_OFFLINE"] = str(hf_hub_offline)
-    reload_modules(True)
+    reload_modules(bool(hf_hub_offline))
     yield
     # Restore the original value of HF_HUB_OFFLINE environment variable
     if original_hf_offline is not None:
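The hunk is truncated just before the restore logic. For orientation, a minimal sketch of an HF_HUB_OFFLINE toggle shaped like hf_offline_context; reload_modules is the repo's own helper (its signature here and the restore branch are assumptions, not copied from the repo):

    import importlib
    import os
    from contextlib import contextmanager

    import huggingface_hub.constants

    def reload_modules(offline: bool) -> None:
        # Stand-in for the repo's reload helper: re-evaluate the module that
        # caches HF_HUB_OFFLINE at import time. The real helper may reload more.
        importlib.reload(huggingface_hub.constants)

    @contextmanager
    def hf_offline_context(hf_hub_offline):
        original = os.getenv("HF_HUB_OFFLINE")
        os.environ["HF_HUB_OFFLINE"] = str(hf_hub_offline)
        reload_modules(bool(hf_hub_offline))
        try:
            yield
        finally:
            # Restore the previous value, or drop the variable if it was unset.
            if original is not None:
                os.environ["HF_HUB_OFFLINE"] = original
            else:
                os.environ.pop("HF_HUB_OFFLINE", None)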
@@ -110,6 +110,34 @@ class TestTokenizers:
         assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
             128042
         ]
+        assert (
+            tokenizer.decode([128041, 128042]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
+        )
+
+    @enable_hf_offline
+    def test_added_tokens_overrides_gemma3(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use with tokenizer that has reserved_tokens in added_tokens
+                "tokenizer_config": "mlx-community/gemma-3-4b-it-8bit",
+                "added_tokens_overrides": {
+                    256001: "RANDOM_OVERRIDE_1",
+                    256002: "RANDOM_OVERRIDE_2",
+                },
+                "output_dir": temp_dir,
+            }
+        )
+
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
+            256001
+        ]
+        assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
+            256002
+        ]
+        assert (
+            tokenizer.decode([256001, 256002]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
+        )
 
     @enable_hf_offline
     def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
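For a quick check outside pytest, the same assertions can be run against the tokenizer the loader returns. A hedged sketch mirroring the new test; the import paths are assumed from the repo layout, and the scratch output_dir is illustrative:

    # Hypothetical standalone spot-check of the gemma3 override round-trip.
    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_tokenizer

    cfg = DictDefault({
        "tokenizer_config": "mlx-community/gemma-3-4b-it-8bit",
        "added_tokens_overrides": {
            256001: "RANDOM_OVERRIDE_1",
            256002: "RANDOM_OVERRIDE_2",
        },
        "output_dir": "/tmp/tok-override-check",  # any scratch directory
    })
    tokenizer = load_tokenizer(cfg)
    assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [256001]
    assert tokenizer.decode([256001, 256002]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"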