Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code
* fix(doc): remove duplicate config
* feat: replace added_tokens in tokenizer and add test
* make sure to run tokenizer modification on rank 0 only
* use is local main process instead
* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -154,8 +154,6 @@ datasets:
|
|||||||
content: value
|
content: value
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
message_property_mappings:
|
|
||||||
|
|
||||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||||
roles:
|
roles:
|
||||||
user: ["human", "user"]
|
user: ["human", "user"]
|
||||||
@@ -556,6 +554,13 @@ special_tokens:
|
|||||||
# Add extra tokens.
|
# Add extra tokens.
|
||||||
tokens:
|
tokens:
|
||||||
|
|
||||||
|
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||||
|
# Only works for tokens that are not part of the base vocab (i.e. tokens listed under added_tokens).
|
||||||
|
# To check whether a token qualifies, look for its id in the added_tokens section of tokenizer.json.
|
||||||
|
added_tokens_overrides: # Dict[int, str]
|
||||||
|
# 128041: "<|im_start|>"
|
||||||
|
# 128042: "<|im_end|>"
|
||||||
|
|
||||||
# FSDP
|
# FSDP
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|||||||
@@ -855,6 +855,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
special_tokens: Optional[SpecialTokensConfig] = None
|
special_tokens: Optional[SpecialTokensConfig] = None
|
||||||
tokens: Optional[List[str]] = None
|
tokens: Optional[List[str]] = None
|
||||||
|
added_tokens_overrides: Optional[Dict[int, str]] = None
|
||||||
|
|
||||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ def is_main_process():
|
|||||||
|
|
||||||
|
|
||||||
def is_local_main_process():
|
def is_local_main_process():
|
||||||
return PartialState().is_main_process
|
return PartialState().is_local_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
|
|||||||
@@ -57,7 +57,13 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
from axolotl.utils.distributed import (
|
||||||
|
barrier,
|
||||||
|
get_device_count,
|
||||||
|
get_device_type,
|
||||||
|
is_local_main_process,
|
||||||
|
zero_only,
|
||||||
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
@@ -165,7 +171,95 @@ def load_model_config(cfg):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
|
def _rename_in_added_tokens_decoder(config_data, token_id_mappings):
    """Apply overrides to the ``added_tokens_decoder`` table of tokenizer_config.json.

    Mutates ``config_data`` in place. Does nothing when the table is absent.

    Raises:
        ValueError: If an overridden token id has no entry in the table.
    """
    if "added_tokens_decoder" not in config_data:
        return

    decoder = config_data["added_tokens_decoder"]
    for token_id, new_value in token_id_mappings.items():
        # JSON object keys are strings, so the integer id is looked up by its string form.
        token_id_str = str(token_id)
        if token_id_str not in decoder:
            raise ValueError(
                f"Token ID {token_id_str} not found in added_tokens_decoder"
            )
        decoder[token_id_str]["content"] = new_value


def _rename_in_added_tokens(tokenizer_data, token_id_mappings):
    """Apply overrides to the ``added_tokens`` list of tokenizer.json.

    Mutates ``tokenizer_data`` in place. Does nothing when the list is absent.

    Raises:
        ValueError: If an overridden token id is not present in the list.
    """
    if "added_tokens" not in tokenizer_data:
        return

    # Index the entries once instead of rescanning the list for every override.
    # setdefault keeps the first entry for an id, matching a first-match linear scan.
    entries_by_id = {}
    for entry in tokenizer_data["added_tokens"]:
        entries_by_id.setdefault(entry["id"], entry)

    for token_id, new_value in token_id_mappings.items():
        if token_id not in entries_by_id:
            raise ValueError(f"Token ID {token_id} not found in added_tokens")
        entries_by_id[token_id]["content"] = new_value


def _rewrite_json_file(path, update, token_id_mappings):
    """Load ``path`` as JSON, apply ``update`` to the data, and write it back.

    A missing file is skipped silently. If ``update`` raises, the file on disk
    is left as it was because nothing has been written yet.
    """
    import json

    if not os.path.exists(path):
        return

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    update(data, token_id_mappings)

    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def modify_tokenizer_files(
    tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
    """
    Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.

    This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.

    The tokenizer is loaded and re-saved under ``<output_dir>/tokenizer``; the
    saved ``tokenizer_config.json`` (``added_tokens_decoder``) and
    ``tokenizer.json`` (``added_tokens``) are then edited on disk. Only the
    process for which ``is_local_main_process()`` is true touches the files;
    every process calls ``barrier()`` before returning.

    Args:
        tokenizer_path: Path or name of the original tokenizer
        token_mappings: Dict mapping {token_id (int): new_token_string}
        output_dir: Directory to save the modified tokenizer

    Returns:
        Path to the modified tokenizer directory

    Raises:
        ValueError: If a token id from ``token_mappings`` is missing from
            ``added_tokens_decoder`` or from ``added_tokens``.

    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
    """
    # Create the tokenizer directory in output_dir if it doesn't exist
    tokenizer_dir = os.path.join(output_dir, "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)

    if is_local_main_process():
        # Materialize the original tokenizer files so they can be edited on disk.
        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
        temp_tokenizer.save_pretrained(tokenizer_dir)

        # Normalize keys to int; config-derived mappings may carry them as strings.
        token_id_mappings = {
            int(token_id): new_value for token_id, new_value in token_mappings.items()
        }

        # 1. Update tokenizer_config.json - added_tokens_decoder
        _rewrite_json_file(
            os.path.join(tokenizer_dir, "tokenizer_config.json"),
            _rename_in_added_tokens_decoder,
            token_id_mappings,
        )

        # 2. Update tokenizer.json - added_tokens
        # (kept in its own expression so the `tokenizer_path` argument is not overwritten)
        _rewrite_json_file(
            os.path.join(tokenizer_dir, "tokenizer.json"),
            _rename_in_added_tokens,
            token_id_mappings,
        )

    # Reached by every process, including those that skipped the edit above.
    barrier()
    return tokenizer_dir
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
|
# Set base tokenizer path
|
||||||
|
tokenizer_path = cfg.tokenizer_config
|
||||||
|
|
||||||
|
# Apply token string overrides if specified
|
||||||
|
if cfg.added_tokens_overrides:
|
||||||
|
# Modify tokenizer files and get path to modified tokenizer
|
||||||
|
tokenizer_path = modify_tokenizer_files(
|
||||||
|
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
cfg.tokenizer_config,
|
tokenizer_path,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
use_fast=use_fast,
|
use_fast=use_fast,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Test cases for the tokenizer loading
|
Test cases for the tokenizer loading
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestTokenizers(unittest.TestCase):
|
class TestTokenizers:
|
||||||
"""
|
"""
|
||||||
test class for the load_tokenizer fn
|
test class for the load_tokenizer fn
|
||||||
"""
|
"""
|
||||||
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
||||||
self.assertEqual(len(tokenizer), 32001)
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
self.assertEqual(len(tokenizer), 32001)
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
|
def test_added_tokens_overrides(self, temp_dir):
    """Overridden added_tokens strings must encode to their original token ids."""
    expected_ids = (
        ("RANDOM_OVERRIDE_1", 128041),
        ("RANDOM_OVERRIDE_2", 128042),
    )
    cfg = DictDefault(
        {
            # use with tokenizer that has reserved_tokens in added_tokens
            "tokenizer_config": "NousResearch/Llama-3.2-1B",
            "added_tokens_overrides": {
                token_id: token_str for token_str, token_id in expected_ids
            },
            "output_dir": temp_dir,
        }
    )

    tokenizer = load_tokenizer(cfg)

    # Each override string should map to exactly the single id it replaced.
    for token_str, token_id in expected_ids:
        assert tokenizer.encode(token_str, add_special_tokens=False) == [token_id]
|
||||||
|
|
||||||
|
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
    """Overriding a token id that is not among the added_tokens must raise ValueError."""
    raw_cfg = {
        # use with tokenizer that has reserved_tokens in added_tokens
        "tokenizer_config": "NousResearch/Llama-3.2-1B",
        "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
        "output_dir": temp_dir,
    }
    cfg = DictDefault(raw_cfg)

    expected_message = r".*Token ID 1000000 not found in added_tokens.*"
    with pytest.raises(ValueError, match=expected_message):
        load_tokenizer(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user