Compare commits

...

9 Commits

Author        SHA1        Message                                                              Date
Dan Saunders  4f39aeefb9  debug                                                                2025-06-09 20:38:46 +00:00
Dan Saunders  8f75136ad3  debug                                                                2025-06-09 20:38:13 +00:00
Dan Saunders  70e9cb545d  update mistral dep version                                           2025-06-09 18:03:45 +00:00
Dan Saunders  aa236a4669  use from_hf_hub                                                      2025-06-09 18:03:43 +00:00
Dan Saunders  65f8988efd  small changes                                                        2025-06-09 18:03:31 +00:00
Dan Saunders  13ddb8f172  Simplify mistral tokenizer identification (depends on upstream PR)  2025-06-09 18:03:31 +00:00
Dan Saunders  b1570ed0fa  update                                                               2025-06-09 18:03:31 +00:00
Dan Saunders  9581a9efed  refactor tokenizer loader + add mistral logic                        2025-06-09 18:03:28 +00:00
Dan Saunders  7e44445494  add mistral-common dep                                               2025-06-09 18:02:28 +00:00
7 changed files with 663 additions and 236 deletions

View File

@@ -20,6 +20,7 @@ datasets==3.6.0
 deepspeed>=0.17.0
 trl==0.18.1
 hf_xet==1.1.2
+mistral-common[hf-hub]==1.6.0
 optimum==1.16.2
 hf_transfer
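
The mistral-common[hf-hub] pin above is what provides MistralTokenizer.from_hf_hub and the chat-completion encoding this diff builds on. A minimal sketch of that API, assuming a Hub repo that ships mistral-common tokenizer files (the repo id here is illustrative):

# Sketch only: the repo id is illustrative; from_hf_hub needs the [hf-hub] extra.
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.from_hf_hub("mistralai/Mistral-Small-24B-Instruct-2501")
tokenized = tokenizer.encode_chat_completion(
    ChatCompletionRequest(messages=[UserMessage(content="Hello, world!")])
)
print(tokenized.tokens[:10])  # token ids, including instruct control tokens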

View File

@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
     ProcessorMixin | None,
 ]:
     """
-    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
-    config.
+    Helper function for loading a model, tokenizer, and processor specified in the
+    given `axolotl` config.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.

View File

@@ -64,6 +64,10 @@ class TokenizedPromptDataset(Dataset):
desc="Strategy Filtering Rows", desc="Strategy Filtering Rows",
) )
import ipdb
ipdb.set_trace()
return dataset.map( return dataset.map(
self.prompt_tokenizer.tokenize_prompt, self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc, num_proc=num_proc,

View File

@@ -2,8 +2,16 @@
 import json
 import os
+from typing import Any, Dict, List, Optional, Union
+
+import torch
 import transformers
+from huggingface_hub import hf_hub_download
+from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer,
+)
 from transformers import (
     AddedToken,
     AutoTokenizer,
@@ -23,239 +31,622 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
 PLUGIN_MANAGER = PluginManager.get_instance()
 
-def modify_tokenizer_files(
-    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
-) -> str:
-    """
-    Modify tokenizer files to replace added_tokens strings, save to output directory,
-    and return the path to the modified tokenizer.
-
-    This only works with reserved tokens that were added to the tokenizer, not tokens
-    already part of the vocab.
-
-    Args:
-        tokenizer_path: Path or name of the original tokenizer
-        token_mappings: Dict mapping {token_id (int): new_token_string}
-        output_dir: Directory to save the modified tokenizer
-
-    Returns:
-        Path to the modified tokenizer directory
-
-    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
-    """
-    # Create the tokenizer directory in output_dir if it doesn't exist
-    tokenizer_dir = os.path.join(output_dir, "tokenizer")
-    os.makedirs(tokenizer_dir, exist_ok=True)
-
-    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
-        # Load the tokenizer
-        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
-
-        # Save the tokenizer to the output directory
-        temp_tokenizer.save_pretrained(tokenizer_dir)
-
-        # Get the token IDs and map them to their new values
-        token_id_mappings = {
-            int(token_id): new_value for token_id, new_value in token_mappings.items()
-        }
-
-        # 1. Update tokenizer_config.json - added_tokens_decoder
-        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r", encoding="utf-8") as f:
-                config_data = json.load(f)
-
-            # Update added_tokens_decoder
-            if "added_tokens_decoder" in config_data:
-                for token_id, new_value in token_id_mappings.items():
-                    token_id_str = str(token_id)
-                    if token_id_str in config_data["added_tokens_decoder"]:
-                        config_data["added_tokens_decoder"][token_id_str][
-                            "content"
-                        ] = new_value
-                    else:
-                        raise ValueError(
-                            f"Token ID {token_id_str} not found in added_tokens_decoder"
-                        )
-
-            # Write the updated config back
-            with open(config_path, "w", encoding="utf-8") as f:
-                json.dump(config_data, f, indent=2)
-
-        # 2. Update tokenizer.json - added_tokens
-        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
-        if os.path.exists(tokenizer_path):
-            with open(tokenizer_path, "r", encoding="utf-8") as f:
-                tokenizer_data = json.load(f)
-
-            # Update added_tokens
-            if "added_tokens" in tokenizer_data:
-                for token_id, new_value in token_id_mappings.items():
-                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
-                        if token_entry["id"] == token_id:
-                            tokenizer_data["added_tokens"][i]["content"] = new_value
-                            break
-                    else:
-                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens
-                        raise ValueError(
-                            f"Token ID {token_id} not found in added_tokens"
-                        )
-
-            if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
-                for token_id, new_value in token_id_mappings.items():
-                    for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
-                        if entry_id == token_id:
-                            del tokenizer_data["model"]["vocab"][entry_val]
-                            tokenizer_data["model"]["vocab"][new_value] = token_id
-                            break
-
-            # Write the updated tokenizer data back
-            with open(tokenizer_path, "w", encoding="utf-8") as f:
-                json.dump(tokenizer_data, f, indent=2)
-
-    barrier()
-    return tokenizer_dir
+# Constants
+LLAMA_TOKENIZER_CLASSES = {
+    "LlamaTokenizer",
+    "LlamaTokenizerFast",
+    "CodeLlamaTokenizer",
+    "CodeLlamaTokenizerFast",
+}
+
+FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
+
+QWEN_DEFAULT_TOKEN = "<|endoftext|>"
+GPTNEOX_PAD_TOKEN = "[PAD]"
+CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
+
+
+class MistralTokenizerWrapper:
+    """
+    Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
+    This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
+    """
+
+    def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
+        self.mistral_tokenizer = mistral_tokenizer
+        self.model_id = model_id
+        self._system_prompt = None
+        self.padding_side = "right"  # Default padding side
+        self.chat_template = None
+
+        # Cache token IDs by inspecting the actual tokenizer
+        self._token_ids = self._discover_token_ids()
+
+        # Try to load system prompt if available
+        try:
+            self._system_prompt = self._load_system_prompt(
+                model_id, "SYSTEM_PROMPT.txt"
+            )
+        except Exception as e:
+            LOG.debug(f"Could not load system prompt: {e}")
+
+    def _discover_token_ids(self) -> Dict[str, int]:
+        """Discover the actual token IDs used by this Mistral tokenizer."""
+        token_ids = {}
+
+        try:
+            if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
+                instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
+
+                # Get BOS token ID from instruct_tokenizer
+                token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
+
+                # Get token IDs from the underlying Tekkenizer
+                if hasattr(instruct_tokenizer, "tokenizer"):
+                    tekkenizer = instruct_tokenizer.tokenizer
+
+                    # Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
+                    if hasattr(tekkenizer, "bos_id"):
+                        token_ids["bos_token_id"] = tekkenizer.bos_id
+
+                    # Get vocab size to help find EOS token
+                    vocab_size = getattr(tekkenizer, "_vocab_size", None)
+
+                    # Check special tokens
+                    if hasattr(tekkenizer, "_all_special_tokens"):
+                        special_tokens = tekkenizer._all_special_tokens
+                        keys = (
+                            list(special_tokens.keys())
+                            if hasattr(special_tokens, "keys")
+                            else special_tokens
+                        )
+                        LOG.debug(f"Special tokens available: {keys}")
+
+                        # Try to find EOS token in special tokens
+                        if hasattr(special_tokens, "get"):
+                            # Common EOS token patterns
+                            for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
+                                if eos_pattern in special_tokens:
+                                    token_ids["eos_token_id"] = special_tokens[
+                                        eos_pattern
+                                    ]
+                                    break
+
+                    # Check special tokens reverse vocab
+                    if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
+                        reverse_vocab = tekkenizer._special_tokens_reverse_vocab
+                        LOG.debug(f"Reverse special tokens: {reverse_vocab}")
+
+                        # Look for common special token IDs
+                        for token_id, token_str in reverse_vocab.items():
+                            if token_str in ["</s>", "<|endoftext|>"]:
+                                token_ids["eos_token_id"] = token_id
+                            elif token_str in ["<unk>", "<UNK>"]:
+                                token_ids["unk_token_id"] = token_id
+
+                    # If we have vocab_size, EOS is often vocab_size - 1 or similar
+                    if "eos_token_id" not in token_ids and vocab_size:
+                        # Common patterns: EOS could be 2, vocab_size-1, or other values
+                        # Let's try a safer approach by checking what tokens decode to
+                        for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
+                            try:
+                                # Try to decode and see if it looks like EOS
+                                decoded = tekkenizer.decode([candidate_id])
+                                if decoded in ["</s>", "<|endoftext|>", ""]:
+                                    token_ids["eos_token_id"] = candidate_id
+                                    break
+                            except Exception:
+                                continue
+        except Exception as e:
+            LOG.debug(f"Could not discover token IDs: {e}")
+
+        # Set reasonable defaults for any missing token IDs
+        token_ids.setdefault("bos_token_id", 1)
+        token_ids.setdefault("eos_token_id", 2)
+        token_ids.setdefault("unk_token_id", 0)
+        token_ids.setdefault(
+            "pad_token_id", token_ids["eos_token_id"]
+        )  # Use EOS as pad
+
+        LOG.info(f"Discovered Mistral token IDs: {token_ids}")
+        return token_ids
+
+    def _load_system_prompt(self, repo_id: str, filename: str) -> str:
+        """Load system prompt from HuggingFace Hub"""
+        file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+        with open(file_path, "r") as file:
+            return file.read()
+
+    def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
+        """Encode text to token IDs"""
+        if isinstance(text, str):
+            # For simple string encoding, create a user message
+            messages = []
+            if self._system_prompt and add_special_tokens:
+                messages.append(SystemMessage(content=self._system_prompt))
+            messages.append(UserMessage(content=text))
+
+            tokenized = self.mistral_tokenizer.encode_chat_completion(
+                ChatCompletionRequest(messages=messages)
+            )
+            return tokenized.tokens
+        else:
+            raise ValueError("MistralTokenizer wrapper only supports string input")
+
+    def decode(
+        self,
+        token_ids: Union[List[int], torch.Tensor],
+        skip_special_tokens: bool = True,
+    ) -> str:
+        """Decode token IDs to text"""
+        if isinstance(token_ids, torch.Tensor):
+            token_ids = token_ids.tolist()
+        return self.mistral_tokenizer.decode(token_ids)
+
+    def __call__(self, text: str, **kwargs):
+        """Make the tokenizer callable like HF tokenizers"""
+        tokens = self.encode(text, **kwargs)
+        return {"input_ids": torch.tensor([tokens])}
+
+    @property
+    def eos_token_id(self):
+        return self._token_ids["eos_token_id"]
+
+    @property
+    def bos_token_id(self):
+        return self._token_ids["bos_token_id"]
+
+    @property
+    def pad_token_id(self):
+        return self._token_ids["pad_token_id"]
+
+    @property
+    def unk_token_id(self):
+        return self._token_ids["unk_token_id"]
+
+    @property
+    def eos_token(self):
+        return "</s>"  # Standard Mistral EOS token
+
+    @property
+    def bos_token(self):
+        return "<s>"  # Standard Mistral BOS token
+
+    @property
+    def pad_token(self):
+        return self.eos_token  # Use EOS as pad token
+
+    @property
+    def unk_token(self):
+        return "<unk>"  # Standard UNK token
+
+    @property
+    def __class__(self):
+        # Create a mock class for compatibility checks
+        class MistralTokenizerWrapperClass:
+            __name__ = "MistralTokenizerWrapper"
+
+        return MistralTokenizerWrapperClass
+
+    def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
+        """Placeholder for special token addition - Mistral tokenizer handles this internally"""
+        LOG.warning(
+            "add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
+        )
+        return 0
+
+    def add_tokens(self, tokens) -> int:
+        """Placeholder for token addition - Mistral tokenizer handles this internally"""
+        LOG.warning(
+            "add_tokens called on MistralTokenizer wrapper - this is handled internally"
+        )
+        return 0
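
A rough usage sketch of the wrapper defined above, mirroring what the new tests at the bottom of this diff exercise (the model id is illustrative, not a recommendation):

# Sketch: assumes the repo hosts mistral-common-compatible tokenizer files.
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

model_id = "mistralai/Mistral-Small-24B-Instruct-2501"  # illustrative
wrapper = MistralTokenizerWrapper(MistralTokenizer.from_hf_hub(model_id), model_id)

ids = wrapper.encode("Hello, world!")  # list[int] via encode_chat_completion
text = wrapper.decode(ids)             # round-trips through mistral-common
batch = wrapper("Hello, world!")       # {"input_ids": tensor of shape (1, n)}
print(wrapper.eos_token_id, wrapper.pad_token_id)  # discovered or defaulted ids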
+
+
+class TokenizerFileModifier:
+    """Handles modification of tokenizer files for token overrides."""
+
+    def __init__(
+        self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
+    ):
+        self.tokenizer_path = tokenizer_path
+        self.token_mappings = token_mappings
+        self.output_dir = output_dir
+        self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
+
+    def modify_and_save(self) -> str:
+        """Modify tokenizer files and return path to modified tokenizer."""
+        os.makedirs(self.tokenizer_dir, exist_ok=True)
+
+        if is_local_main_process():
+            self._perform_modifications()
+
+        barrier()
+        return self.tokenizer_dir
+
+    def _perform_modifications(self):
+        """Perform the actual file modifications."""
+        # Load and save tokenizer to output directory
+        temp_tokenizer = AutoTokenizer.from_pretrained(
+            self.tokenizer_path, use_fast=True
+        )
+        temp_tokenizer.save_pretrained(self.tokenizer_dir)
+
+        # Convert token mappings to proper format
+        token_id_mappings = {
+            int(token_id): new_value
+            for token_id, new_value in self.token_mappings.items()
+        }
+
+        # Update both tokenizer files
+        self._update_tokenizer_config(token_id_mappings)
+        self._update_tokenizer_json(token_id_mappings)
+
+    def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
+        """Update tokenizer_config.json with new token mappings."""
+        config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
+        if not os.path.exists(config_path):
+            return
+
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_data = json.load(f)
+
+        if "added_tokens_decoder" in config_data:
+            self._update_added_tokens_decoder(config_data, token_id_mappings)
+
+        with open(config_path, "w", encoding="utf-8") as f:
+            json.dump(config_data, f, indent=2)
+
+    def _update_added_tokens_decoder(
+        self, config_data: Dict, token_id_mappings: Dict[int, str]
+    ):
+        """Update the added_tokens_decoder section."""
+        for token_id, new_value in token_id_mappings.items():
+            token_id_str = str(token_id)
+            if token_id_str in config_data["added_tokens_decoder"]:
+                config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
+            else:
+                raise ValueError(
+                    f"Token ID {token_id_str} not found in added_tokens_decoder"
+                )
+
+    def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
+        """Update tokenizer.json with new token mappings."""
+        tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
+        if not os.path.exists(tokenizer_json_path):
+            return
+
+        with open(tokenizer_json_path, "r", encoding="utf-8") as f:
+            tokenizer_data = json.load(f)
+
+        self._update_added_tokens_list(tokenizer_data, token_id_mappings)
+        self._update_vocab_mappings(tokenizer_data, token_id_mappings)
+
+        with open(tokenizer_json_path, "w", encoding="utf-8") as f:
+            json.dump(tokenizer_data, f, indent=2)
+
+    def _update_added_tokens_list(
+        self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
+    ):
+        """Update the added_tokens list in tokenizer.json."""
+        if "added_tokens" not in tokenizer_data:
+            return
+
+        for token_id, new_value in token_id_mappings.items():
+            for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
+                if token_entry["id"] == token_id:
+                    tokenizer_data["added_tokens"][i]["content"] = new_value
+                    break
+            else:
+                raise ValueError(f"Token ID {token_id} not found in added_tokens")
+
+    def _update_vocab_mappings(
+        self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
+    ):
+        """Update vocab mappings in tokenizer.json."""
+        if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
+            return
+
+        vocab = tokenizer_data["model"]["vocab"]
+        for token_id, new_value in token_id_mappings.items():
+            # Find and update the vocab entry
+            for entry_val, entry_id in list(vocab.items()):
+                if entry_id == token_id:
+                    del vocab[entry_val]
+                    vocab[new_value] = token_id
+                    break
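
TokenizerFileModifier is driven by the added_tokens_overrides config key, a mapping of reserved-token id to replacement string. A sketch of the intended call, with hypothetical ids that would have to exist among the tokenizer's added_tokens (not the base vocab):

# Sketch: token ids, strings, and the repo id are hypothetical placeholders.
modifier = TokenizerFileModifier(
    tokenizer_path="mistralai/Mistral-7B-v0.1",
    token_mappings={32002: "<|im_start|>", 32003: "<|im_end|>"},
    output_dir="./outputs",
)
new_path = modifier.modify_and_save()  # "./outputs/tokenizer"
# tokenizer_config.json's added_tokens_decoder and tokenizer.json's
# added_tokens/vocab entries for those ids are rewritten in place.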
+
+
+class TokenizerConfiguration:
+    """Handles tokenizer configuration and initialization."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.model_config = load_model_config(cfg)
+
+    def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
+        """Load Mistral tokenizer from model configuration."""
+        # Instantiate Mistral tokenizer
+        model_id = self.cfg.base_model
+        mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)
+
+        # Wrap it for compatibility
+        tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
+
+        LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
+        return tokenizer
+
+    def get_tokenizer_class(self):
+        """Get the appropriate tokenizer class."""
+        if self.cfg.tokenizer_type:
+            return getattr(transformers, self.cfg.tokenizer_type)
+        return AutoTokenizer
+
+    def get_tokenizer_kwargs(self) -> Dict[str, Any]:
+        """Build tokenizer initialization kwargs."""
+        kwargs = {}
+        if self.cfg.tokenizer_legacy is not None:
+            kwargs["legacy"] = self.cfg.tokenizer_legacy
+        return kwargs
+
+    def get_tokenizer_path(self) -> str:
+        """Get the tokenizer path, applying overrides if needed."""
+        tokenizer_path = self.cfg.tokenizer_config
+        if self.cfg.added_tokens_overrides:
+            modifier = TokenizerFileModifier(
+                tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
+            )
+            tokenizer_path = modifier.modify_and_save()
+        return tokenizer_path
+
+    def should_use_fast_tokenizer(self) -> bool:
+        """Determine if fast tokenizer should be used."""
+        return (
+            self.cfg.tokenizer_use_fast
+            if self.cfg.tokenizer_use_fast is not None
+            else True
+        )
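
TokenizerConfiguration reads a handful of config keys. A sketch of the relevant shape, using DictDefault as the tests in this diff do (values illustrative; note the constructor also calls load_model_config, which may fetch the model config from the Hub):

from axolotl.utils.dict import DictDefault  # import path as used in axolotl's tests

cfg = DictDefault(
    {
        "base_model": "mistralai/Mistral-7B-v0.1",        # passed to from_hf_hub
        "tokenizer_config": "mistralai/Mistral-7B-v0.1",  # HF fallback path
        "tokenizer_type": None,          # e.g. "LlamaTokenizer" to force a class
        "tokenizer_use_fast": None,      # treated as True when unset
        "tokenizer_legacy": None,        # forwarded as legacy= when set
        "added_tokens_overrides": None,  # {token_id: new_string} file rewrite
        "output_dir": "./outputs",
    }
)
config = TokenizerConfiguration(cfg)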
+
+
+class TokenizerPostProcessor:
+    """Handles post-processing configuration of loaded tokenizers."""
+
+    def __init__(self, tokenizer, cfg):
+        self.tokenizer = tokenizer
+        self.cfg = cfg
+        self.model_config = load_model_config(cfg)
+
+    def apply_all_configurations(self):
+        """Apply all post-processing configurations to the tokenizer."""
+        # Skip most configurations for Mistral wrapper
+        if isinstance(self.tokenizer, MistralTokenizerWrapper):
+            self._configure_mistral_wrapper()
+            return
+
+        self._configure_padding_token()
+        self._configure_gptneox_settings()
+        self._configure_mistral_padding()
+        self._configure_qwen_tokens()
+        self._add_special_tokens()
+        self._add_regular_tokens()
+        self._configure_chat_template()
+
+    def _configure_mistral_wrapper(self):
+        """Apply limited configurations for Mistral wrapper."""
+        # Set padding side if needed
+        if (
+            self.cfg.is_mistral_derived_model
+            and self.cfg.flash_attention
+            and not self.cfg.sample_packing
+        ):
+            self.tokenizer.padding_side = "left"
+
+        # Configure chat template for Mistral
+        self._configure_chat_template()
+
+    def _configure_padding_token(self):
+        """Configure padding token for Llama-based tokenizers."""
+        if (
+            self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
+            and hasattr(self.tokenizer, "pad_token")
+            and not self.tokenizer.pad_token
+        ):
+            self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
+
+    def _configure_gptneox_settings(self):
+        """Configure GPTNeoX-specific settings."""
+        if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+            self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    def _configure_mistral_padding(self):
+        """Configure left padding for Mistral models with Flash Attention."""
+        if (
+            self.cfg.is_mistral_derived_model
+            and self.cfg.flash_attention
+            and not self.cfg.sample_packing
+        ):
+            self.tokenizer.padding_side = "left"
+
+    def _configure_qwen_tokens(self):
+        """Configure special tokens for Qwen models."""
+        if not self.cfg.is_qwen_derived_model:
+            return
+
+        # Set token IDs
+        token_id_attributes = [
+            "bos_token_id",
+            "eos_token_id",
+            "pad_token_id",
+            "unk_token_id",
+        ]
+        for attr_name in token_id_attributes:
+            if getattr(self.tokenizer, attr_name) is None:
+                setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
+
+        # Set token strings
+        token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
+        for attr_name in token_name_attributes:
+            if getattr(self.tokenizer, attr_name) is None:
+                setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
+
+    def _add_special_tokens(self):
+        """Add special tokens from configuration."""
+        if not self.cfg.special_tokens:
+            return
+
+        special_tokens_dict = self.cfg.special_tokens.to_dict()
+        additional_special_tokens = special_tokens_dict.pop(
+            "additional_special_tokens", None
+        )
+
+        self._validate_and_add_special_tokens(special_tokens_dict)
+        self._update_post_processor_if_needed(special_tokens_dict)
+        self._add_additional_special_tokens_if_present(additional_special_tokens)
+
+    def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
+        """Validate special tokens for adapter training and add them."""
+        lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
+
+        for key, value in special_tokens.items():
+            self._validate_token_for_adapter(key, value, lora_modules_to_save)
+            self.tokenizer.add_special_tokens(
+                {key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
+            )
+
+    def _validate_token_for_adapter(
+        self, key: str, value: str, lora_modules_to_save: List[str]
+    ):
+        """Validate a single token for adapter training requirements."""
+        if not self._should_validate_token_for_adapter(
+            key, value, lora_modules_to_save
+        ):
+            return
+
+        modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
+        raise ValueError(
+            f"Please set lora_modules_to_save to [{modules_str}] "
+            f"when using an adapter and changing the special tokens."
+        )
+
+    def _should_validate_token_for_adapter(
+        self, key: str, value: str, lora_modules_to_save: List[str]
+    ) -> bool:
+        """Check if token should be validated for adapter configuration."""
+        if key == "pad_token" or not self.cfg.adapter:
+            return False
+
+        current_token = getattr(self.tokenizer, key)
+        token_changed = current_token is None or current_token != value
+        token_is_multi_char = (
+            len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
+        )
+        lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
+            x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
+        )
+
+        return token_changed and token_is_multi_char and lora_modules_missing
+
+    def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
+        """Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
+        has_bos_and_eos = (
+            "bos_token" in special_tokens and "eos_token" in special_tokens
+        )
+        is_fast_llama = (
+            self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
+        )
+
+        if is_fast_llama and has_bos_and_eos:
+            self.tokenizer.update_post_processor()
+
+    def _add_additional_special_tokens_if_present(
+        self, additional_special_tokens: Optional[List[str]]
+    ):
+        """Add additional special tokens if they exist."""
+        if additional_special_tokens is not None:
+            self.tokenizer.add_special_tokens(
+                {"additional_special_tokens": additional_special_tokens}
+            )
+
+    def _add_regular_tokens(self):
+        """Add regular (non-special) tokens from configuration."""
+        if self.cfg.tokens:
+            self.tokenizer.add_tokens(
+                [
+                    AddedToken(token, rstrip=False, lstrip=False, normalized=False)
+                    for token in self.cfg.tokens
+                ]
+            )
+
+    def _configure_chat_template(self):
+        """Configure chat template if specified."""
+        if not self.cfg.chat_template:
+            LOG.info(
+                "No Chat template selected. Consider adding a chat template for easier inference."
+            )
+            return
+
+        chat_template_string = get_chat_template_from_config(
+            cfg=self.cfg,
+            tokenizer=self.tokenizer,
+        )
+
+        if self._should_replace_default_system_message():
+            chat_template_string = chat_template_string.replace(
+                CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
+            )
+
+        self.tokenizer.chat_template = chat_template_string
+
+    def _should_replace_default_system_message(self) -> bool:
+        """Check if default system message should be replaced."""
+        return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
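
The adapter check above preserves the pre-refactor rule: changing a special token to a multi-token string while training an adapter requires the embedding layers to be in lora_modules_to_save. Roughly, with the module names get_linear_embedding_layers typically returns for llama-family models (illustrative, not guaranteed):

# Sketch: module names assume a llama-family model_type.
cfg = DictDefault(
    {
        "adapter": "lora",
        "special_tokens": {"eos_token": "<|im_end|>"},
        "lora_modules_to_save": ["embed_tokens", "lm_head"],
    }
)
# Omitting lora_modules_to_save here would make _validate_token_for_adapter
# raise: Please set lora_modules_to_save to [`embed_tokens`, `lm_head`] ...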
 
 def load_tokenizer(cfg):
-    """Load and configure the tokenizer based on the provided config."""
-    model_config = load_model_config(cfg)
-    tokenizer_kwargs = {}
-    use_fast = True  # this is the default
-
-    if cfg.tokenizer_use_fast is not None:
-        use_fast = cfg.tokenizer_use_fast
-    if cfg.tokenizer_legacy is not None:
-        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
-        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
-
-    tokenizer_cls = AutoTokenizer
-    if cfg.tokenizer_type:
-        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
-
-    # Set base tokenizer path
-    tokenizer_path = cfg.tokenizer_config
-
-    # Apply token string overrides if specified
-    if cfg.added_tokens_overrides:
-        # Modify tokenizer files and get path to modified tokenizer
-        tokenizer_path = modify_tokenizer_files(
-            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
-        )
-
-    tokenizer = tokenizer_cls.from_pretrained(
-        tokenizer_path,
-        trust_remote_code=cfg.trust_remote_code or False,
-        use_fast=use_fast,
-        **tokenizer_kwargs,
-    )
-
-    if (
-        tokenizer.__class__.__name__
-        in [
-            "LlamaTokenizer",
-            "LlamaTokenizerFast",
-            "CodeLlamaTokenizer",
-            "CodeLlamaTokenizerFast",
-        ]
-        and hasattr(tokenizer, "pad_token")
-        and not tokenizer.pad_token
-    ):
-        # set a pad_token, but use eos_token so we don't add a new token
-        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    # Mistral's official FA implementation requires left padding
-    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
-        tokenizer.padding_side = "left"
-
-    # Qwen base only has single token, so we need to set the special tokens
-    if cfg.is_qwen_derived_model:
-        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
-        for attr_name in token_ids:
-            if getattr(tokenizer, attr_name) is None:
-                setattr(tokenizer, attr_name, tokenizer.eod_id)
-
-        token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
-        for attr_name in token_names:
-            if getattr(tokenizer, attr_name) is None:
-                setattr(tokenizer, attr_name, "<|endoftext|>")
-
-    additional_special_tokens = None
-    if cfg.special_tokens:
-        special_tokens = cfg.special_tokens.to_dict()
-        additional_special_tokens = special_tokens.pop(
-            "additional_special_tokens", None
-        )
-        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
-        for k, val in special_tokens.items():
-            # check if new special token is not already in tokenizer and
-            # is adapter training to make sure lora_modules_to_save is set
-            # pylint: disable=too-many-boolean-expressions
-            if (
-                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
-                and cfg.adapter
-                and (
-                    not cfg.lora_modules_to_save
-                    or not all(
-                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
-                    )
-                )
-                and k != "pad_token"
-            ):
-                lora_modules_to_save = ", ".join(
-                    [f"`{x}`" for x in lora_modules_to_save]
-                )
-                raise ValueError(
-                    f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
-                )
-
-            tokenizer.add_special_tokens(
-                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
-            )
-
-        # If we add bos_token and eos_token, we need to update the post processor to
-        # handle them correctly.
-        # https://github.com/huggingface/transformers/pull/24132
-        bos_or_eos_in_special_tokens = (
-            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
-        )
-        if (
-            tokenizer.__class__.__name__
-            in (
-                "LlamaTokenizerFast",
-                "CodeLlamaTokenizerFast",
-            )
-            and bos_or_eos_in_special_tokens
-        ):
-            tokenizer.update_post_processor()
-
-    if cfg.tokens:
-        tokenizer.add_tokens(
-            [
-                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
-                for token in cfg.tokens
-            ]
-        )
-
-    # Additional special tokens are a List, and need to be treated differently than regular special
-    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
-    # are new tokens.
-    #
-    # Usage:
-    #
-    # ```py
-    # special_tokens:
-    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
-    # ```
-    if additional_special_tokens is not None:
-        tokenizer.add_special_tokens(
-            {"additional_special_tokens": additional_special_tokens}
-        )
+    """Load and configure the tokenizer based on the provided config.
+
+    This function handles the complete tokenizer loading pipeline:
+    - Check if Mistral tokenizer should be used
+    - Configure tokenizer parameters and get the appropriate class
+    - Handle token file modifications if needed
+    - Initialize the tokenizer with the correct parameters
+    - Apply all post-processing configurations (padding, special tokens, etc.)
+    - Set up chat templates and logging
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+
+    Returns:
+        Fully configured tokenizer instance.
+    """
+    # Configure tokenizer parameters
+    config = TokenizerConfiguration(cfg)
+
+    # Check if we should use Mistral tokenizer
+    try:
+        tokenizer = config.load_mistral_tokenizer()
+    except:
+        # Standard tokenizer loading
+        tokenizer_cls = config.get_tokenizer_class()
+        tokenizer_path = config.get_tokenizer_path()
+        use_fast = config.should_use_fast_tokenizer()
+        tokenizer_kwargs = config.get_tokenizer_kwargs()
+
+        # Initialize the tokenizer
+        tokenizer = tokenizer_cls.from_pretrained(
+            tokenizer_path,
+            trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
+            **tokenizer_kwargs,
+        )
+
+    # Apply all post-processing configurations
+    post_processor = TokenizerPostProcessor(tokenizer, cfg)
+    post_processor.apply_all_configurations()
 
     if is_main_process(use_environ=True):
         LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -263,19 +654,4 @@ def load_tokenizer(cfg):
         LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
         LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
-    if cfg.chat_template:
-        chat_template_string = get_chat_template_from_config(
-            cfg=cfg,
-            tokenizer=tokenizer,
-        )
-        if cfg.default_system_message and cfg.chat_template == "chatml":
-            chat_template_string = chat_template_string.replace(
-                "You are a helpful assistant.", cfg.default_system_message
-            )
-
-        tokenizer.chat_template = chat_template_string
-    else:
-        LOG.info(
-            "No Chat template selected. Consider adding a chat template for easier inference."
-        )
-
     return tokenizer
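
End to end, the refactored loader still performs the chatml default-system-message swap, now inside TokenizerPostProcessor._configure_chat_template. A sketch of driving it from a config (model id illustrative):

# Sketch: assumes a llama-family tokenizer and the chatml template.
cfg = DictDefault(
    {
        "base_model": "NousResearch/Llama-2-7b-hf",
        "tokenizer_config": "NousResearch/Llama-2-7b-hf",
        "chat_template": "chatml",
        "default_system_message": "You are a terse assistant.",
    }
)
tokenizer = load_tokenizer(cfg)
# "You are a helpful assistant." in the chatml template is replaced with the
# configured default_system_message before being set on the tokenizer.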

View File

@@ -67,6 +67,10 @@ class PromptTokenizingStrategy(abc.ABC):
             LOG.warning("Empty text requested for tokenization.")
             return empty
 
+        import ipdb
+        ipdb.set_trace()
+
         result = self.tokenizer(
             prompt,
             truncation=True,

View File

@@ -486,6 +486,10 @@ def get_dataset_wrapper(
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}" f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
) )
import ipdb
ipdb.set_trace()
if ( if (
isinstance(dataset, Dataset) isinstance(dataset, Dataset)
and "input_ids" in dataset.features and "input_ids" in dataset.features

View File

@@ -1,8 +1,4 @@
""" """Test cases for tokenizer loading."""
Test cases for the tokenizer loading
"""
import unittest
import pytest import pytest
@@ -13,9 +9,7 @@ from tests.hf_offline_utils import enable_hf_offline
 class TestTokenizers:
-    """
-    test class for the load_tokenizer fn
-    """
+    """Test class for the load_tokenizer fn"""
 
     @enable_hf_offline
     def test_default_use_fast(self):
@@ -155,6 +149,50 @@ class TestTokenizers:
         ):
             load_tokenizer(cfg)
 
-
-if __name__ == "__main__":
-    unittest.main()
+    def test_mistral_tokenizer_auto_detection(self):
+        """Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
+
+    def test_mixtral_tokenizer_auto_detection(self):
+        """Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "model-hub/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
+
+    def test_mistral_tokenizer_basic_functionality(self):
+        """Test basic encode/decode functionality of MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+
+        # Test basic encoding
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+        # Test basic decoding
+        decoded = tokenizer.decode(tokens)
+        assert isinstance(decoded, str)
+
+        # Test token properties are accessible
+        assert hasattr(tokenizer, "eos_token_id")
+        assert hasattr(tokenizer, "bos_token_id")
+        assert isinstance(tokenizer.eos_token_id, int)
+        assert isinstance(tokenizer.bos_token_id, int)
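
These tests download real tokenizer files from the Hub. Assuming the file keeps its usual location in the axolotl repo (path illustrative), the new cases can be run selectively with pytest's keyword filter:

pytest tests/test_tokenizers.py -k "mistral" -x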