Compare commits

...

9 Commits

Author        SHA1        Message                                                              Date
Dan Saunders  4f39aeefb9  debug                                                                2025-06-09 20:38:46 +00:00
Dan Saunders  8f75136ad3  debug                                                                2025-06-09 20:38:13 +00:00
Dan Saunders  70e9cb545d  update mistral dep version                                           2025-06-09 18:03:45 +00:00
Dan Saunders  aa236a4669  use from_hf_hub                                                      2025-06-09 18:03:43 +00:00
Dan Saunders  65f8988efd  small changes                                                        2025-06-09 18:03:31 +00:00
Dan Saunders  13ddb8f172  Simplify mistral tokenizer identification (depends on upstream PR)  2025-06-09 18:03:31 +00:00
Dan Saunders  b1570ed0fa  update                                                               2025-06-09 18:03:31 +00:00
Dan Saunders  9581a9efed  refactor tokenizer loader + add mistral logic                        2025-06-09 18:03:28 +00:00
Dan Saunders  7e44445494  add mistral-common dep                                               2025-06-09 18:02:28 +00:00
7 changed files with 663 additions and 236 deletions

View File

@@ -20,6 +20,7 @@ datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.1
hf_xet==1.1.2
mistral-common[hf-hub]==1.6.0
optimum==1.16.2
hf_transfer
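The new `mistral-common[hf-hub]` pin is what the `use from_hf_hub` commit builds on: the `hf-hub` extra pulls in the Hub integration so the tokenizer can be fetched straight from a Hub repo. A minimal sketch of what the dependency enables, using the same ungated mirror the new tests rely on:

```py
# Sketch: fetch a Mistral tokenizer directly from the Hugging Face Hub.
# Requires mistral-common[hf-hub]==1.6.0; the repo id mirrors the one in the tests.
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.from_hf_hub("adamo1139/Mistral-Small-24B-Instruct-2501-ungated")
```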

View File

@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
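For reference, a hedged sketch of calling the helper documented above; the import path is an assumption, and the return tuple (ending in `ProcessorMixin | None` per the annotation fragment) is left packed:

```py
# Sketch: drive the documented helper with a minimal axolotl config.
# The module path and the single config key are illustrative assumptions.
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_and_tokenizer  # assumed location

cfg = DictDefault({"base_model": "mistralai/Mistral-7B-v0.1"})  # illustrative
result = load_model_and_tokenizer(cfg=cfg)  # (..., ProcessorMixin | None)
```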

View File

@@ -64,6 +64,10 @@ class TokenizedPromptDataset(Dataset):
desc="Strategy Filtering Rows",
)
import ipdb
ipdb.set_trace()
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -2,8 +2,16 @@
import json
import os
from typing import Any, Dict, List, Optional, Union
import torch
import transformers
from huggingface_hub import hf_hub_download
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer,
)
from transformers import (
AddedToken,
AutoTokenizer,
@@ -23,239 +31,622 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
+# Constants
+LLAMA_TOKENIZER_CLASSES = {
+    "LlamaTokenizer",
+    "LlamaTokenizerFast",
+    "CodeLlamaTokenizer",
+    "CodeLlamaTokenizerFast",
+}
-def modify_tokenizer_files(
-    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
-) -> str:
+FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
+QWEN_DEFAULT_TOKEN = "<|endoftext|>"
+GPTNEOX_PAD_TOKEN = "[PAD]"
+CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
+class MistralTokenizerWrapper:
    """
-    Modify tokenizer files to replace added_tokens strings, save to output directory,
-    and return the path to the modified tokenizer.
-    This only works with reserved tokens that were added to the tokenizer, not tokens
-    already part of the vocab.
-    Args:
-        tokenizer_path: Path or name of the original tokenizer
-        token_mappings: Dict mapping {token_id (int): new_token_string}
-        output_dir: Directory to save the modified tokenizer
-    Returns:
-        Path to the modified tokenizer directory
-    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+    Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
+    This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
    """
-    # Create the tokenizer directory in output_dir if it doesn't exist
-    tokenizer_dir = os.path.join(output_dir, "tokenizer")
-    os.makedirs(tokenizer_dir, exist_ok=True)
-    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
-        # Load the tokenizer
-        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+    def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
+        self.mistral_tokenizer = mistral_tokenizer
+        self.model_id = model_id
+        self._system_prompt = None
+        self.padding_side = "right"  # Default padding side
+        self.chat_template = None
-        # Save the tokenizer to the output directory
-        temp_tokenizer.save_pretrained(tokenizer_dir)
+        # Cache token IDs by inspecting the actual tokenizer
+        self._token_ids = self._discover_token_ids()
-    # Get the token IDs and map them to their new values
+        # Try to load system prompt if available
+        try:
+            self._system_prompt = self._load_system_prompt(
+                model_id, "SYSTEM_PROMPT.txt"
+            )
+        except Exception as e:
+            LOG.debug(f"Could not load system prompt: {e}")
def _discover_token_ids(self) -> Dict[str, int]:
"""Discover the actual token IDs used by this Mistral tokenizer."""
token_ids = {}
try:
if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
# Get BOS token ID from instruct_tokenizer
token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
# Get token IDs from the underlying Tekkenizer
if hasattr(instruct_tokenizer, "tokenizer"):
tekkenizer = instruct_tokenizer.tokenizer
# Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
if hasattr(tekkenizer, "bos_id"):
token_ids["bos_token_id"] = tekkenizer.bos_id
# Get vocab size to help find EOS token
vocab_size = getattr(tekkenizer, "_vocab_size", None)
# Check special tokens
if hasattr(tekkenizer, "_all_special_tokens"):
special_tokens = tekkenizer._all_special_tokens
keys = (
list(special_tokens.keys())
if hasattr(special_tokens, "keys")
else special_tokens
)
LOG.debug(f"Special tokens available: {keys}")
# Try to find EOS token in special tokens
if hasattr(special_tokens, "get"):
# Common EOS token patterns
for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
if eos_pattern in special_tokens:
token_ids["eos_token_id"] = special_tokens[
eos_pattern
]
break
# Check special tokens reverse vocab
if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
reverse_vocab = tekkenizer._special_tokens_reverse_vocab
LOG.debug(f"Reverse special tokens: {reverse_vocab}")
# Look for common special token IDs
for token_id, token_str in reverse_vocab.items():
if token_str in ["</s>", "<|endoftext|>"]:
token_ids["eos_token_id"] = token_id
elif token_str in ["<unk>", "<UNK>"]:
token_ids["unk_token_id"] = token_id
# If we have vocab_size, EOS is often vocab_size - 1 or similar
if "eos_token_id" not in token_ids and vocab_size:
# Common patterns: EOS could be 2, vocab_size-1, or other values
# Let's try a safer approach by checking what tokens decode to
for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
try:
# Try to decode and see if it looks like EOS
decoded = tekkenizer.decode([candidate_id])
if decoded in ["</s>", "<|endoftext|>", ""]:
token_ids["eos_token_id"] = candidate_id
break
except Exception:
continue
except Exception as e:
LOG.debug(f"Could not discover token IDs: {e}")
# Set reasonable defaults for any missing token IDs
token_ids.setdefault("bos_token_id", 1)
token_ids.setdefault("eos_token_id", 2)
token_ids.setdefault("unk_token_id", 0)
token_ids.setdefault(
"pad_token_id", token_ids["eos_token_id"]
) # Use EOS as pad
LOG.info(f"Discovered Mistral token IDs: {token_ids}")
return token_ids
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
"""Load system prompt from HuggingFace Hub"""
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
with open(file_path, "r") as file:
return file.read()
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
"""Encode text to token IDs"""
if isinstance(text, str):
# For simple string encoding, create a user message
messages = []
if self._system_prompt and add_special_tokens:
messages.append(SystemMessage(content=self._system_prompt))
messages.append(UserMessage(content=text))
tokenized = self.mistral_tokenizer.encode_chat_completion(
ChatCompletionRequest(messages=messages)
)
return tokenized.tokens
else:
raise ValueError("MistralTokenizer wrapper only supports string input")
def decode(
self,
token_ids: Union[List[int], torch.Tensor],
skip_special_tokens: bool = True,
) -> str:
"""Decode token IDs to text"""
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
return self.mistral_tokenizer.decode(token_ids)
def __call__(self, text: str, **kwargs):
"""Make the tokenizer callable like HF tokenizers"""
tokens = self.encode(text, **kwargs)
return {"input_ids": torch.tensor([tokens])}
@property
def eos_token_id(self):
return self._token_ids["eos_token_id"]
@property
def bos_token_id(self):
return self._token_ids["bos_token_id"]
@property
def pad_token_id(self):
return self._token_ids["pad_token_id"]
@property
def unk_token_id(self):
return self._token_ids["unk_token_id"]
@property
def eos_token(self):
return "</s>" # Standard Mistral EOS token
@property
def bos_token(self):
return "<s>" # Standard Mistral BOS token
@property
def pad_token(self):
return self.eos_token # Use EOS as pad token
@property
def unk_token(self):
return "<unk>" # Standard UNK token
@property
def __class__(self):
# Create a mock class for compatibility checks
class MistralTokenizerWrapperClass:
__name__ = "MistralTokenizerWrapper"
return MistralTokenizerWrapperClass
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
LOG.warning(
"add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
)
return 0
def add_tokens(self, tokens) -> int:
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
LOG.warning(
"add_tokens called on MistralTokenizer wrapper - this is handled internally"
)
return 0
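A hedged usage sketch of the wrapper defined above. The repo id mirrors the one used in this PR's tests; everything else sticks to methods and properties the class actually defines:

```py
# Sketch: make Mistral's native tokenizer quack like an HF tokenizer.
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

model_id = "adamo1139/Mistral-Small-24B-Instruct-2501-ungated"
wrapper = MistralTokenizerWrapper(MistralTokenizer.from_hf_hub(model_id), model_id)

ids = wrapper.encode("Hello, world!")  # list[int], via encode_chat_completion
text = wrapper.decode(ids)             # str
batch = wrapper("Hello, world!")       # {"input_ids": tensor of shape (1, n)}
print(wrapper.bos_token_id, wrapper.eos_token_id)  # discovered or defaulted ids
```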
class TokenizerFileModifier:
"""Handles modification of tokenizer files for token overrides."""
def __init__(
self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
):
self.tokenizer_path = tokenizer_path
self.token_mappings = token_mappings
self.output_dir = output_dir
self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
def modify_and_save(self) -> str:
"""Modify tokenizer files and return path to modified tokenizer."""
os.makedirs(self.tokenizer_dir, exist_ok=True)
if is_local_main_process():
self._perform_modifications()
barrier()
return self.tokenizer_dir
+    def _perform_modifications(self):
+        """Perform the actual file modifications."""
+        # Load and save tokenizer to output directory
+        temp_tokenizer = AutoTokenizer.from_pretrained(
+            self.tokenizer_path, use_fast=True
+        )
+        temp_tokenizer.save_pretrained(self.tokenizer_dir)
+        # Convert token mappings to proper format
        token_id_mappings = {
-            int(token_id): new_value for token_id, new_value in token_mappings.items()
+            int(token_id): new_value
+            for token_id, new_value in self.token_mappings.items()
        }
-        # 1. Update tokenizer_config.json - added_tokens_decoder
-        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r", encoding="utf-8") as f:
-                config_data = json.load(f)
+        # Update both tokenizer files
+        self._update_tokenizer_config(token_id_mappings)
+        self._update_tokenizer_json(token_id_mappings)
-            # Update added_tokens_decoder
-            if "added_tokens_decoder" in config_data:
-                for token_id, new_value in token_id_mappings.items():
-                    token_id_str = str(token_id)
-                    if token_id_str in config_data["added_tokens_decoder"]:
-                        config_data["added_tokens_decoder"][token_id_str][
-                            "content"
-                        ] = new_value
-                    else:
-                        raise ValueError(
-                            f"Token ID {token_id_str} not found in added_tokens_decoder"
-                        )
+    def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
+        """Update tokenizer_config.json with new token mappings."""
+        config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
+        if not os.path.exists(config_path):
+            return
-            # Write the updated config back
-            with open(config_path, "w", encoding="utf-8") as f:
-                json.dump(config_data, f, indent=2)
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_data = json.load(f)
-        # 2. Update tokenizer.json - added_tokens
-        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
-        if os.path.exists(tokenizer_path):
-            with open(tokenizer_path, "r", encoding="utf-8") as f:
-                tokenizer_data = json.load(f)
+        if "added_tokens_decoder" in config_data:
+            self._update_added_tokens_decoder(config_data, token_id_mappings)
-            # Update added_tokens
-            if "added_tokens" in tokenizer_data:
-                for token_id, new_value in token_id_mappings.items():
-                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
-                        if token_entry["id"] == token_id:
-                            tokenizer_data["added_tokens"][i]["content"] = new_value
-                            break
-                    else:
-                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens
-                        raise ValueError(
-                            f"Token ID {token_id} not found in added_tokens"
-                        )
-            if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
-                for token_id, new_value in token_id_mappings.items():
-                    for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
-                        if entry_id == token_id:
-                            del tokenizer_data["model"]["vocab"][entry_val]
-                            tokenizer_data["model"]["vocab"][new_value] = token_id
-                            break
+        with open(config_path, "w", encoding="utf-8") as f:
+            json.dump(config_data, f, indent=2)
-            # Write the updated tokenizer data back
-            with open(tokenizer_path, "w", encoding="utf-8") as f:
-                json.dump(tokenizer_data, f, indent=2)
+    def _update_added_tokens_decoder(
+        self, config_data: Dict, token_id_mappings: Dict[int, str]
+    ):
+        """Update the added_tokens_decoder section."""
+        for token_id, new_value in token_id_mappings.items():
+            token_id_str = str(token_id)
+            if token_id_str in config_data["added_tokens_decoder"]:
+                config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
+            else:
+                raise ValueError(
+                    f"Token ID {token_id_str} not found in added_tokens_decoder"
+                )
-    barrier()
-    return tokenizer_dir
def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
"""Update tokenizer.json with new token mappings."""
tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
if not os.path.exists(tokenizer_json_path):
return
with open(tokenizer_json_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
self._update_added_tokens_list(tokenizer_data, token_id_mappings)
self._update_vocab_mappings(tokenizer_data, token_id_mappings)
with open(tokenizer_json_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
def _update_added_tokens_list(
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
):
"""Update the added_tokens list in tokenizer.json."""
if "added_tokens" not in tokenizer_data:
return
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
raise ValueError(f"Token ID {token_id} not found in added_tokens")
def _update_vocab_mappings(
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
):
"""Update vocab mappings in tokenizer.json."""
if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
return
vocab = tokenizer_data["model"]["vocab"]
for token_id, new_value in token_id_mappings.items():
# Find and update the vocab entry
for entry_val, entry_id in list(vocab.items()):
if entry_id == token_id:
del vocab[entry_val]
vocab[new_value] = token_id
break
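A short sketch of how the modifier is meant to be driven, mirroring the `added_tokens_overrides` path in `load_tokenizer` below. The token ids and replacement strings are made up for illustration; per the original docstring removed above, this only works for reserved added tokens, not base-vocab entries:

```py
# Sketch: rename two reserved added tokens, then load the patched tokenizer.
# Ids 32000/32001 and the ChatML strings are illustrative assumptions.
modifier = TokenizerFileModifier(
    tokenizer_path="mistralai/Mistral-7B-v0.1",
    token_mappings={32000: "<|im_start|>", 32001: "<|im_end|>"},
    output_dir="./outputs",
)
patched_path = modifier.modify_and_save()  # rank 0 writes; others wait at barrier()

tokenizer = AutoTokenizer.from_pretrained(patched_path, use_fast=True)
```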
class TokenizerConfiguration:
"""Handles tokenizer configuration and initialization."""
def __init__(self, cfg):
self.cfg = cfg
self.model_config = load_model_config(cfg)
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
"""Load Mistral tokenizer from model configuration."""
# Instantiate Mistral tokenizer
model_id = self.cfg.base_model
mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)
# Wrap it for compatibility
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
return tokenizer
def get_tokenizer_class(self):
"""Get the appropriate tokenizer class."""
if self.cfg.tokenizer_type:
return getattr(transformers, self.cfg.tokenizer_type)
return AutoTokenizer
def get_tokenizer_kwargs(self) -> Dict[str, Any]:
"""Build tokenizer initialization kwargs."""
kwargs = {}
if self.cfg.tokenizer_legacy is not None:
kwargs["legacy"] = self.cfg.tokenizer_legacy
return kwargs
def get_tokenizer_path(self) -> str:
"""Get the tokenizer path, applying overrides if needed."""
tokenizer_path = self.cfg.tokenizer_config
if self.cfg.added_tokens_overrides:
modifier = TokenizerFileModifier(
tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
)
tokenizer_path = modifier.modify_and_save()
return tokenizer_path
def should_use_fast_tokenizer(self) -> bool:
"""Determine if fast tokenizer should be used."""
return (
self.cfg.tokenizer_use_fast
if self.cfg.tokenizer_use_fast is not None
else True
)
class TokenizerPostProcessor:
"""Handles post-processing configuration of loaded tokenizers."""
def __init__(self, tokenizer, cfg):
self.tokenizer = tokenizer
self.cfg = cfg
self.model_config = load_model_config(cfg)
def apply_all_configurations(self):
"""Apply all post-processing configurations to the tokenizer."""
# Skip most configurations for Mistral wrapper
if isinstance(self.tokenizer, MistralTokenizerWrapper):
self._configure_mistral_wrapper()
return
self._configure_padding_token()
self._configure_gptneox_settings()
self._configure_mistral_padding()
self._configure_qwen_tokens()
self._add_special_tokens()
self._add_regular_tokens()
self._configure_chat_template()
def _configure_mistral_wrapper(self):
"""Apply limited configurations for Mistral wrapper."""
# Set padding side if needed
if (
self.cfg.is_mistral_derived_model
and self.cfg.flash_attention
and not self.cfg.sample_packing
):
self.tokenizer.padding_side = "left"
# Configure chat template for Mistral
self._configure_chat_template()
def _configure_padding_token(self):
"""Configure padding token for Llama-based tokenizers."""
if (
self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
and hasattr(self.tokenizer, "pad_token")
and not self.tokenizer.pad_token
):
self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
def _configure_gptneox_settings(self):
"""Configure GPTNeoX-specific settings."""
if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def _configure_mistral_padding(self):
"""Configure left padding for Mistral models with Flash Attention."""
if (
self.cfg.is_mistral_derived_model
and self.cfg.flash_attention
and not self.cfg.sample_packing
):
self.tokenizer.padding_side = "left"
def _configure_qwen_tokens(self):
"""Configure special tokens for Qwen models."""
if not self.cfg.is_qwen_derived_model:
return
# Set token IDs
token_id_attributes = [
"bos_token_id",
"eos_token_id",
"pad_token_id",
"unk_token_id",
]
for attr_name in token_id_attributes:
if getattr(self.tokenizer, attr_name) is None:
setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
# Set token strings
token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_name_attributes:
if getattr(self.tokenizer, attr_name) is None:
setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
def _add_special_tokens(self):
"""Add special tokens from configuration."""
if not self.cfg.special_tokens:
return
special_tokens_dict = self.cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens_dict.pop(
"additional_special_tokens", None
)
self._validate_and_add_special_tokens(special_tokens_dict)
self._update_post_processor_if_needed(special_tokens_dict)
self._add_additional_special_tokens_if_present(additional_special_tokens)
def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
"""Validate special tokens for adapter training and add them."""
lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
for key, value in special_tokens.items():
self._validate_token_for_adapter(key, value, lora_modules_to_save)
self.tokenizer.add_special_tokens(
{key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
)
def _validate_token_for_adapter(
self, key: str, value: str, lora_modules_to_save: List[str]
):
"""Validate a single token for adapter training requirements."""
if not self._should_validate_token_for_adapter(
key, value, lora_modules_to_save
):
return
modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
raise ValueError(
f"Please set lora_modules_to_save to [{modules_str}] "
f"when using an adapter and changing the special tokens."
)
def _should_validate_token_for_adapter(
self, key: str, value: str, lora_modules_to_save: List[str]
) -> bool:
"""Check if token should be validated for adapter configuration."""
if key == "pad_token" or not self.cfg.adapter:
return False
current_token = getattr(self.tokenizer, key)
token_changed = current_token is None or current_token != value
token_is_multi_char = (
len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
)
lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
)
return token_changed and token_is_multi_char and lora_modules_missing
def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
"""Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
has_bos_and_eos = (
"bos_token" in special_tokens and "eos_token" in special_tokens
)
is_fast_llama = (
self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
)
if is_fast_llama and has_bos_and_eos:
self.tokenizer.update_post_processor()
def _add_additional_special_tokens_if_present(
self, additional_special_tokens: Optional[List[str]]
):
"""Add additional special tokens if they exist."""
if additional_special_tokens is not None:
self.tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
def _add_regular_tokens(self):
"""Add regular (non-special) tokens from configuration."""
if self.cfg.tokens:
self.tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in self.cfg.tokens
]
)
def _configure_chat_template(self):
"""Configure chat template if specified."""
if not self.cfg.chat_template:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return
chat_template_string = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
if self._should_replace_default_system_message():
chat_template_string = chat_template_string.replace(
CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
)
self.tokenizer.chat_template = chat_template_string
def _should_replace_default_system_message(self) -> bool:
"""Check if default system message should be replaced."""
return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
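The chat-template handling above boils down to a plain string substitution when `chat_template == "chatml"`. A reduced sketch, with a stand-in template rather than the real chatml string:

```py
# Sketch: how _configure_chat_template swaps the default system message.
# The template text here is a stand-in, not the actual chatml template.
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
template = "<|im_start|>system\n" + CHATML_DEFAULT_SYSTEM_MESSAGE + "<|im_end|>\n"

default_system_message = "You are a terse coding assistant."  # cfg.default_system_message
template = template.replace(CHATML_DEFAULT_SYSTEM_MESSAGE, default_system_message)
```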
def load_tokenizer(cfg):
-    """Load and configure the tokenizer based on the provided config."""
-    model_config = load_model_config(cfg)
-    tokenizer_kwargs = {}
-    use_fast = True  # this is the default
+    """Load and configure the tokenizer based on the provided config.
-    if cfg.tokenizer_use_fast is not None:
-        use_fast = cfg.tokenizer_use_fast
-    if cfg.tokenizer_legacy is not None:
-        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
-        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
+    This function handles the complete tokenizer loading pipeline:
+    - Check if Mistral tokenizer should be used
+    - Configure tokenizer parameters and get the appropriate class
+    - Handle token file modifications if needed
+    - Initialize the tokenizer with the correct parameters
+    - Apply all post-processing configurations (padding, special tokens, etc.)
+    - Set up chat templates and logging
-    tokenizer_cls = AutoTokenizer
-    if cfg.tokenizer_type:
-        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
-    # Set base tokenizer path
-    tokenizer_path = cfg.tokenizer_config
+    Returns:
+        Fully configured tokenizer instance.
+    """
+    # Configure tokenizer parameters
+    config = TokenizerConfiguration(cfg)
-    # Apply token string overrides if specified
-    if cfg.added_tokens_overrides:
-        # Modify tokenizer files and get path to modified tokenizer
-        tokenizer_path = modify_tokenizer_files(
-            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
-        )
+    # Check if we should use Mistral tokenizer
+    try:
+        tokenizer = config.load_mistral_tokenizer()
+    except Exception:
+        # Standard tokenizer loading
+        tokenizer_cls = config.get_tokenizer_class()
+        tokenizer_path = config.get_tokenizer_path()
+        use_fast = config.should_use_fast_tokenizer()
+        tokenizer_kwargs = config.get_tokenizer_kwargs()
+        # Initialize the tokenizer
+        tokenizer = tokenizer_cls.from_pretrained(
+            tokenizer_path,
+            trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
+            **tokenizer_kwargs,
+        )
-    tokenizer = tokenizer_cls.from_pretrained(
-        tokenizer_path,
-        trust_remote_code=cfg.trust_remote_code or False,
-        use_fast=use_fast,
-        **tokenizer_kwargs,
-    )
-    if (
-        tokenizer.__class__.__name__
-        in [
-            "LlamaTokenizer",
-            "LlamaTokenizerFast",
-            "CodeLlamaTokenizer",
-            "CodeLlamaTokenizerFast",
-        ]
-        and hasattr(tokenizer, "pad_token")
-        and not tokenizer.pad_token
-    ):
-        # set a pad_token, but use eos_token so we don't add a new token
-        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    # Mistral's official FA implementation requires left padding
-    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
-        tokenizer.padding_side = "left"
-    # Qwen base only has single token, so we need to set the special tokens
-    if cfg.is_qwen_derived_model:
-        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
-        for attr_name in token_ids:
-            if getattr(tokenizer, attr_name) is None:
-                setattr(tokenizer, attr_name, tokenizer.eod_id)
-        token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
-        for attr_name in token_names:
-            if getattr(tokenizer, attr_name) is None:
-                setattr(tokenizer, attr_name, "<|endoftext|>")
-    additional_special_tokens = None
-    if cfg.special_tokens:
-        special_tokens = cfg.special_tokens.to_dict()
-        additional_special_tokens = special_tokens.pop(
-            "additional_special_tokens", None
-        )
-        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
-        for k, val in special_tokens.items():
-            # check if new special token is not already in tokenizer and
-            # is adapter training to make sure lora_modules_to_save is set
-            # pylint: disable=too-many-boolean-expressions
-            if (
-                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
-                and cfg.adapter
-                and (
-                    not cfg.lora_modules_to_save
-                    or not all(
-                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
-                    )
-                )
-                and k != "pad_token"
-            ):
-                lora_modules_to_save = ", ".join(
-                    [f"`{x}`" for x in lora_modules_to_save]
-                )
-                raise ValueError(
-                    f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
-                )
-            tokenizer.add_special_tokens(
-                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
-            )
-        # If we add bos_token and eos_token, we need to update the post processor to
-        # handle them correctly.
-        # https://github.com/huggingface/transformers/pull/24132
-        bos_or_eos_in_special_tokens = (
-            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
-        )
-        if (
-            tokenizer.__class__.__name__
-            in (
-                "LlamaTokenizerFast",
-                "CodeLlamaTokenizerFast",
-            )
-            and bos_or_eos_in_special_tokens
-        ):
-            tokenizer.update_post_processor()
-    if cfg.tokens:
-        tokenizer.add_tokens(
-            [
-                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
-                for token in cfg.tokens
-            ]
-        )
-    # Additional special tokens are a List, and need to be treated differently than regular special
-    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
-    # are new tokens.
-    #
-    # Usage:
-    #
-    # ```py
-    # special_tokens:
-    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
-    # ```
-    if additional_special_tokens is not None:
-        tokenizer.add_special_tokens(
-            {"additional_special_tokens": additional_special_tokens}
-        )
+    # Apply all post-processing configurations
+    post_processor = TokenizerPostProcessor(tokenizer, cfg)
+    post_processor.apply_all_configurations()
    if is_main_process(use_environ=True):
        LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -263,19 +654,4 @@ def load_tokenizer(cfg):
        LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
        LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-    if cfg.chat_template:
-        chat_template_string = get_chat_template_from_config(
-            cfg=cfg,
-            tokenizer=tokenizer,
-        )
-        if cfg.default_system_message and cfg.chat_template == "chatml":
-            chat_template_string = chat_template_string.replace(
-                "You are a helpful assistant.", cfg.default_system_message
-            )
-        tokenizer.chat_template = chat_template_string
-    else:
-        LOG.info(
-            "No Chat template selected. Consider adding a chat template for easier inference."
-        )
return tokenizer
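End to end, the refactored loader can be exercised the same way the new tests below do; a hedged sketch:

```py
# Sketch: run the full pipeline. A Mistral Hub repo takes the try-branch and
# yields a MistralTokenizerWrapper; anything else falls back to AutoTokenizer.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
        "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
    }
)
tokenizer = load_tokenizer(cfg)
print(type(tokenizer).__name__)  # "MistralTokenizerWrapper" on the Mistral path
```

Note that because the dispatch is try/except rather than an explicit model-type check, any failure inside `load_mistral_tokenizer` (including network errors) silently falls through to the standard HF path.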

View File

@@ -67,6 +67,10 @@ class PromptTokenizingStrategy(abc.ABC):
LOG.warning("Empty text requested for tokenization.")
return empty
import ipdb
ipdb.set_trace()
result = self.tokenizer(
prompt,
truncation=True,

View File

@@ -486,6 +486,10 @@ def get_dataset_wrapper(
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
import ipdb
ipdb.set_trace()
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features

View File

@@ -1,8 +1,4 @@
"""
Test cases for the tokenizer loading
"""
import unittest
"""Test cases for tokenizer loading."""
import pytest
@@ -13,9 +9,7 @@ from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers:
"""
test class for the load_tokenizer fn
"""
"""Test class for the load_tokenizer fn"""
@enable_hf_offline
def test_default_use_fast(self):
@@ -155,6 +149,50 @@ class TestTokenizers:
):
load_tokenizer(cfg)
def test_mistral_tokenizer_auto_detection(self):
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
-if __name__ == "__main__":
-    unittest.main()
def test_mixtral_tokenizer_auto_detection(self):
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "model-hub/Mixtral-8x7B-v0.1",
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
def test_mistral_tokenizer_basic_functionality(self):
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
# Test basic encoding
text = "Hello, world!"
tokens = tokenizer.encode(text)
assert isinstance(tokens, list)
assert len(tokens) > 0
# Test basic decoding
decoded = tokenizer.decode(tokens)
assert isinstance(decoded, str)
# Test token properties are accessible
assert hasattr(tokenizer, "eos_token_id")
assert hasattr(tokenizer, "bos_token_id")
assert isinstance(tokenizer.eos_token_id, int)
assert isinstance(tokenizer.bos_token_id, int)