Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code
* fix(doc): remove duplicate config
* feat: replace added_tokens in tokenizer and add test
* make sure to run tokenizer modification on rank 0 only
* use is local main process instead
* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
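For orientation, here is a minimal usage sketch of the new added_tokens_overrides option, closely following the test added in this commit; the output_dir value is a placeholder and the override strings are taken from the documentation example in the diff below.

    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_tokenizer

    # Remap two reserved added_tokens of the Llama 3.2 tokenizer to new strings.
    cfg = DictDefault(
        {
            "tokenizer_config": "NousResearch/Llama-3.2-1B",
            "added_tokens_overrides": {
                128041: "<|im_start|>",
                128042: "<|im_end|>",
            },
            # the rewritten tokenizer files are saved under <output_dir>/tokenizer
            "output_dir": "/tmp/axolotl-example",
        }
    )

    tokenizer = load_tokenizer(cfg)
    assert tokenizer.encode("<|im_start|>", add_special_tokens=False) == [128041]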
@@ -154,8 +154,6 @@ datasets:
       content: value
   # ...
 
-  message_property_mappings:
-
   # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
   roles:
     user: ["human", "user"]
@@ -556,6 +554,13 @@ special_tokens:
 # Add extra tokens.
 tokens:
 
+# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
+# Only works for tokens that are not part of the base vocab (aka are added_tokens).
+# Can be checked if they exist in tokenizer.json added_tokens.
+added_tokens_overrides: # Dict[int, str]
+  # 128041: "<|im_start|>"
+  # 128042: "<|im_end|>"
+
 # FSDP
 fsdp:
 fsdp_config:
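A quick check of how the documented YAML fragment is parsed (a sketch, assuming PyYAML is available): the unquoted numeric keys come back as Python ints, which is what the Dict[int, str] type on the config field expects.

    import yaml  # PyYAML, assumed available

    snippet = """
    added_tokens_overrides:
      128041: "<|im_start|>"
      128042: "<|im_end|>"
    """

    parsed = yaml.safe_load(snippet)
    # Unquoted numeric YAML keys are parsed as integers, not strings.
    assert parsed["added_tokens_overrides"] == {128041: "<|im_start|>", 128042: "<|im_end|>"}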
@@ -855,6 +855,7 @@ class AxolotlInputConfig(
 
     special_tokens: Optional[SpecialTokensConfig] = None
     tokens: Optional[List[str]] = None
+    added_tokens_overrides: Optional[Dict[int, str]] = None
 
     torch_compile: Optional[Union[Literal["auto"], bool]] = None
     torch_compile_backend: Optional[str] = None
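For completeness, a hypothetical stand-in model (not the real AxolotlInputConfig) showing that a pydantic Dict[int, str] field also accepts string keys, for example when the mapping arrives from JSON, by coercing them to ints in non-strict mode:

    from typing import Dict, Optional

    from pydantic import BaseModel


    class TokenOverridesSketch(BaseModel):
        """Hypothetical stand-in for the added_tokens_overrides field only."""

        added_tokens_overrides: Optional[Dict[int, str]] = None


    # JSON object keys are always strings; pydantic coerces them to int here.
    cfg = TokenOverridesSketch(added_tokens_overrides={"128041": "<|im_start|>"})
    assert cfg.added_tokens_overrides == {128041: "<|im_start|>"}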
@@ -79,7 +79,7 @@ def is_main_process():
 
 
 def is_local_main_process():
-    return PartialState().is_main_process
+    return PartialState().is_local_main_process
 
 
 def get_world_size():
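The one-line change above matters on multi-node runs: is_main_process is true only on global rank 0, while is_local_main_process is true on rank 0 of every node, which is typically the right guard for per-node filesystem work such as writing the modified tokenizer. A rough sketch of the distinction using accelerate's PartialState:

    from accelerate import PartialState

    state = PartialState()

    # True only on global rank 0 across all nodes.
    if state.is_main_process:
        pass  # e.g. logging, pushing artifacts once per job

    # True on the first process of each node; use this for work every node
    # needs on its own filesystem, like writing tokenizer files to output_dir.
    if state.is_local_main_process:
        pass  # e.g. write files that other local ranks will read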
@@ -57,7 +57,13 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
+from axolotl.utils.distributed import (
+    barrier,
+    get_device_count,
+    get_device_type,
+    is_local_main_process,
+    zero_only,
+)
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
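The two new imports, barrier and is_local_main_process, enable the usual write-once-then-synchronize pattern that modify_tokenizer_files (added below) relies on; roughly:

    from axolotl.utils.distributed import barrier, is_local_main_process

    if is_local_main_process():
        ...  # one process per node rewrites the tokenizer files on disk
    barrier()  # every rank waits here, so no one loads the files before they exist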
@@ -165,7 +171,95 @@ def load_model_config(cfg):
     return model_config
 
 
+def modify_tokenizer_files(
+    tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
+) -> str:
+    """
+    Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
+
+    This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
+
+    Args:
+        tokenizer_path: Path or name of the original tokenizer
+        token_mappings: Dict mapping {token_id (int): new_token_string}
+        output_dir: Directory to save the modified tokenizer
+
+    Returns:
+        Path to the modified tokenizer directory
+
+    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+    """
+
+    import json
+
+    # Create the tokenizer directory in output_dir if it doesn't exist
+    tokenizer_dir = os.path.join(output_dir, "tokenizer")
+    os.makedirs(tokenizer_dir, exist_ok=True)
+
+    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
+        # Load the tokenizer
+        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+
+        # Save the tokenizer to the output directory
+        temp_tokenizer.save_pretrained(tokenizer_dir)
+
+        # Get the token IDs and map them to their new values
+        token_id_mappings = {
+            int(token_id): new_value for token_id, new_value in token_mappings.items()
+        }
+
+        # 1. Update tokenizer_config.json - added_tokens_decoder
+        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_data = json.load(f)
+
+            # Update added_tokens_decoder
+            if "added_tokens_decoder" in config_data:
+                for token_id, new_value in token_id_mappings.items():
+                    token_id_str = str(token_id)
+                    if token_id_str in config_data["added_tokens_decoder"]:
+                        config_data["added_tokens_decoder"][token_id_str][
+                            "content"
+                        ] = new_value
+                    else:
+                        raise ValueError(
+                            f"Token ID {token_id_str} not found in added_tokens_decoder"
+                        )
+
+            # Write the updated config back
+            with open(config_path, "w", encoding="utf-8") as f:
+                json.dump(config_data, f, indent=2)
+
+        # 2. Update tokenizer.json - added_tokens
+        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        if os.path.exists(tokenizer_path):
+            with open(tokenizer_path, "r", encoding="utf-8") as f:
+                tokenizer_data = json.load(f)
+
+            # Update added_tokens
+            if "added_tokens" in tokenizer_data:
+                for token_id, new_value in token_id_mappings.items():
+                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
+                        if token_entry["id"] == token_id:
+                            tokenizer_data["added_tokens"][i]["content"] = new_value
+                            break
+                    else:
+                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens
+                        raise ValueError(
+                            f"Token ID {token_id} not found in added_tokens"
+                        )
+
+            # Write the updated tokenizer data back
+            with open(tokenizer_path, "w", encoding="utf-8") as f:
+                json.dump(tokenizer_data, f, indent=2)
+
+    barrier()
+    return tokenizer_dir
+
+
 def load_tokenizer(cfg):
     """Load and configure the tokenizer based on the provided config."""
     model_config = load_model_config(cfg)
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
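For reference, the two JSON structures the function rewrites look roughly like this (illustrative shapes only; field sets vary by tokenizer, and the reserved token name shown is a placeholder):

    # tokenizer_config.json: added_tokens_decoder keys are *strings* of token ids
    tokenizer_config_excerpt = {
        "added_tokens_decoder": {
            "128041": {"content": "<|reserved_special_token_N|>", "special": True},
        }
    }

    # tokenizer.json: added_tokens is a *list* of entries with integer ids
    tokenizer_json_excerpt = {
        "added_tokens": [
            {"id": 128041, "content": "<|reserved_special_token_N|>", "special": True},
        ]
    }

    # The override {128041: "<|im_start|>"} replaces the "content" value in both places.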
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
+    # Set base tokenizer path
+    tokenizer_path = cfg.tokenizer_config
+
+    # Apply token string overrides if specified
+    if cfg.added_tokens_overrides:
+        # Modify tokenizer files and get path to modified tokenizer
+        tokenizer_path = modify_tokenizer_files(
+            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+        )
+
     tokenizer = tokenizer_cls.from_pretrained(
-        cfg.tokenizer_config,
+        tokenizer_path,
         trust_remote_code=cfg.trust_remote_code or False,
         use_fast=use_fast,
         **tokenizer_kwargs,
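End to end, when added_tokens_overrides is set, from_pretrained is pointed at the rewritten copy under <output_dir>/tokenizer instead of at cfg.tokenizer_config. A small sanity-check sketch of the resulting behaviour (the path and token id are illustrative):

    from transformers import AutoTokenizer

    # Load the rewritten tokenizer directly from the output directory.
    tok = AutoTokenizer.from_pretrained("/tmp/axolotl-example/tokenizer")

    # The overridden id now round-trips through the new string.
    assert tok.convert_ids_to_tokens(128041) == "<|im_start|>"
    assert tok.convert_tokens_to_ids("<|im_start|>") == 128041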
@@ -1,6 +1,7 @@
 """
 Test cases for the tokenizer loading
 """
 
 import unittest
 
+import pytest
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer
 
 
-class TestTokenizers(unittest.TestCase):
+class TestTokenizers:
     """
     test class for the load_tokenizer fn
     """
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
             }
         )
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
-        self.assertEqual(len(tokenizer), 32001)
+        assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
+        assert len(tokenizer) == 32001
 
         # ensure reloading the tokenizer again from cfg results in same vocab length
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(len(tokenizer), 32001)
+        assert len(tokenizer) == 32001
+
+    def test_added_tokens_overrides(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use with tokenizer that has reserved_tokens in added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {
+                    128041: "RANDOM_OVERRIDE_1",
+                    128042: "RANDOM_OVERRIDE_2",
+                },
+                "output_dir": temp_dir,
+            }
+        )
+
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
+            128041
+        ]
+        assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
+            128042
+        ]
+
+    def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use with tokenizer that has reserved_tokens in added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
+                "output_dir": temp_dir,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
+        ):
+            load_tokenizer(cfg)
 
 
 if __name__ == "__main__":