From 13ddb8f172156c12fd75c477802f4a99ef7ad7aa Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 5 Jun 2025 07:00:50 +0000 Subject: [PATCH] Simplify mistral tokenizer identification (depends on upstream PR) --- src/axolotl/loaders/tokenizer.py | 332 +++++++++++++++++-------------- 1 file changed, 186 insertions(+), 146 deletions(-) diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 925db9bd2..5837cb73d 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -1,4 +1,4 @@ -"""Tokenizer loading functionality and associated utils.""" +"""Tokenizer loading functionality and associated utils""" import json import os @@ -9,9 +9,14 @@ import transformers from huggingface_hub import hf_hub_download from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.tokens.tokenizers.base import SpecialTokens -from mistral_common.tokens.tokenizers.mistral import MistralTokenizer -from transformers import AddedToken, AutoTokenizer +from mistral_common.tokens.tokenizers.mistral import ( + MODEL_NAME_TO_TOKENIZER_CLS, + MistralTokenizer, +) +from transformers import ( + AddedToken, + AutoTokenizer, +) from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config @@ -34,51 +39,146 @@ LLAMA_TOKENIZER_CLASSES = { "CodeLlamaTokenizer", "CodeLlamaTokenizerFast", } -FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"} -MISTRAL_MODEL_TYPES = {"mistral", "mistral3", "mixtral"} -QWEN_DEFAULT_TOKEN = "<|endoftext|>" # nosec B105 -GPTNEOX_PAD_TOKEN = "[PAD]" # nosec B105 +FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"} + +QWEN_DEFAULT_TOKEN = "<|endoftext|>" +GPTNEOX_PAD_TOKEN = "[PAD]" CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." class MistralTokenizerWrapper: """ - Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer - interface. This provides a bridge between Mistral's native tokenizer and axolotl's - expectations. + Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface. + This provides a bridge between Mistral's native tokenizer and axolotl's expectations. """ - def __init__( - self, - mistral_tokenizer: MistralTokenizer, - model_id: str, - system_prompt: str | None = None, - ): + def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str): self.mistral_tokenizer = mistral_tokenizer self.model_id = model_id - self.system_prompt = system_prompt + self._system_prompt = None self.padding_side = "right" # Default padding side self.chat_template = None - # pylint: disable=unused-argument + # Cache token IDs by inspecting the actual tokenizer + self._token_ids = self._discover_token_ids() + + # Try to load system prompt if available + try: + self._system_prompt = self._load_system_prompt( + model_id, "SYSTEM_PROMPT.txt" + ) + except Exception as e: + LOG.debug(f"Could not load system prompt: {e}") + + def _discover_token_ids(self) -> Dict[str, int]: + """Discover the actual token IDs used by this Mistral tokenizer.""" + token_ids = {} + + try: + if hasattr(self.mistral_tokenizer, "instruct_tokenizer"): + instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer + + # Get BOS token ID from instruct_tokenizer + token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1) + + # Get token IDs from the underlying Tekkenizer + if hasattr(instruct_tokenizer, "tokenizer"): + tekkenizer = instruct_tokenizer.tokenizer + + # Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS) + if hasattr(tekkenizer, "bos_id"): + token_ids["bos_token_id"] = tekkenizer.bos_id + + # Get vocab size to help find EOS token + vocab_size = getattr(tekkenizer, "_vocab_size", None) + + # Check special tokens + if hasattr(tekkenizer, "_all_special_tokens"): + special_tokens = tekkenizer._all_special_tokens + keys = ( + list(special_tokens.keys()) + if hasattr(special_tokens, "keys") + else special_tokens + ) + LOG.debug(f"Special tokens available: {keys}") + + # Try to find EOS token in special tokens + if hasattr(special_tokens, "get"): + # Common EOS token patterns + for eos_pattern in ["", "<|endoftext|>", "eos", "EOS"]: + if eos_pattern in special_tokens: + token_ids["eos_token_id"] = special_tokens[ + eos_pattern + ] + break + + # Check special tokens reverse vocab + if hasattr(tekkenizer, "_special_tokens_reverse_vocab"): + reverse_vocab = tekkenizer._special_tokens_reverse_vocab + LOG.debug(f"Reverse special tokens: {reverse_vocab}") + + # Look for common special token IDs + for token_id, token_str in reverse_vocab.items(): + if token_str in ["", "<|endoftext|>"]: + token_ids["eos_token_id"] = token_id + elif token_str in ["", ""]: + token_ids["unk_token_id"] = token_id + + # If we have vocab_size, EOS is often vocab_size - 1 or similar + if "eos_token_id" not in token_ids and vocab_size: + # Common patterns: EOS could be 2, vocab_size-1, or other values + # Let's try a safer approach by checking what tokens decode to + for candidate_id in [2, vocab_size - 1, vocab_size - 2]: + try: + # Try to decode and see if it looks like EOS + decoded = tekkenizer.decode([candidate_id]) + if decoded in ["", "<|endoftext|>", ""]: + token_ids["eos_token_id"] = candidate_id + break + except Exception: + continue + + except Exception as e: + LOG.debug(f"Could not discover token IDs: {e}") + + # Set reasonable defaults for any missing token IDs + token_ids.setdefault("bos_token_id", 1) + token_ids.setdefault("eos_token_id", 2) + token_ids.setdefault("unk_token_id", 0) + token_ids.setdefault( + "pad_token_id", token_ids["eos_token_id"] + ) # Use EOS as pad + + LOG.info(f"Discovered Mistral token IDs: {token_ids}") + return token_ids + + def _load_system_prompt(self, repo_id: str, filename: str) -> str: + """Load system prompt from HuggingFace Hub""" + file_path = hf_hub_download(repo_id=repo_id, filename=filename) + with open(file_path, "r") as file: + return file.read() + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]: """Encode text to token IDs""" - # For simple string encoding, create a user message - messages = [] - if self.system_prompt and add_special_tokens: - messages.append(SystemMessage(content=self.system_prompt)) - messages.append(UserMessage(content=text)) + if isinstance(text, str): + # For simple string encoding, create a user message + messages = [] + if self._system_prompt and add_special_tokens: + messages.append(SystemMessage(content=self._system_prompt)) + messages.append(UserMessage(content=text)) - tokenized = self.mistral_tokenizer.encode_chat_completion( - ChatCompletionRequest(messages=messages) - ) - return tokenized.tokens + tokenized = self.mistral_tokenizer.encode_chat_completion( + ChatCompletionRequest(messages=messages) + ) + return tokenized.tokens + else: + raise ValueError("MistralTokenizer wrapper only supports string input") def decode( self, token_ids: Union[List[int], torch.Tensor], - skip_special_tokens: bool = True, # pylint: disable=unused-argument + skip_special_tokens: bool = True, ) -> str: """Decode token IDs to text""" if isinstance(token_ids, torch.Tensor): @@ -91,19 +191,28 @@ class MistralTokenizerWrapper: return {"input_ids": torch.tensor([tokens])} @property - def special_tokens_reverse_vocab(self): - # pylint: disable=protected-access - return ( - self.mistral_tokenizer.instruct_tokenizer.tokenizer._special_tokens_reverse_vocab - ) + def eos_token_id(self): + return self._token_ids["eos_token_id"] + + @property + def bos_token_id(self): + return self._token_ids["bos_token_id"] + + @property + def pad_token_id(self): + return self._token_ids["pad_token_id"] + + @property + def unk_token_id(self): + return self._token_ids["unk_token_id"] @property def eos_token(self): - return SpecialTokens.eos + return "" # Standard Mistral EOS token @property def bos_token(self): - return SpecialTokens.bos + return "" # Standard Mistral BOS token @property def pad_token(self): @@ -111,25 +220,16 @@ class MistralTokenizerWrapper: @property def unk_token(self): - return SpecialTokens.unk + return "" # Standard UNK token @property - def eos_token_id(self): - return self.special_tokens_reverse_vocab[self.eos_token] + def __class__(self): + # Create a mock class for compatibility checks + class MistralTokenizerWrapperClass: + __name__ = "MistralTokenizerWrapper" - @property - def bos_token_id(self): - return self.special_tokens_reverse_vocab[self.bos_token] + return MistralTokenizerWrapperClass - @property - def pad_token_id(self): - return self.special_tokens_reverse_vocab[self.pad_token] - - @property - def unk_token_id(self): - return self.special_tokens_reverse_vocab[self.unk_token] - - # pylint: disable=unused-argument def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int: """Placeholder for special token addition - Mistral tokenizer handles this internally""" LOG.warning( @@ -137,7 +237,6 @@ class MistralTokenizerWrapper: ) return 0 - # pylint: disable=unused-argument def add_tokens(self, tokens) -> int: """Placeholder for token addition - Mistral tokenizer handles this internally""" LOG.warning( @@ -267,102 +366,43 @@ class TokenizerConfiguration: self.cfg = cfg self.model_config = load_model_config(cfg) - def should_use_mistral_tokenizer(self) -> bool: - """Determine if Mistral tokenizer should be used.""" - return self.model_config.model_type in MISTRAL_MODEL_TYPES + def detect_by_model_name_mapping(self) -> bool: + model_path = getattr(self.cfg, "model_name_or_path", "") or getattr( + self.cfg, "base_model", "" + ) + if not model_path: + return False + + # Extract model name from path + model = model_path.split("/")[-1] + + for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys(): + if model_name in model.lower(): + return True + + return False def load_mistral_tokenizer(self) -> MistralTokenizerWrapper: """Load Mistral tokenizer from model configuration.""" - # Try to find the appropriate tokenizer file - model_id = self.cfg.base_model - tokenizer_file = self._find_mistral_tokenizer_file(model_id) + model_id = getattr(self.cfg, "model_name_or_path", None) or getattr( + self.cfg, "base_model", None + ) + if not model_id: + raise ValueError( + "model_name_or_path or base_model must be specified for Mistral tokenizer" + ) - # Load the Mistral tokenizer and wrap for transformers compatibility - mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file) + # First try to use the model name mapping for direct instantiation + model_name = model_id.split("/")[-1] # Extract model name from path + tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[model_name] + mistral_tokenizer = tokenizer_factory() + + # Wrap it for compatibility wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id) LOG.info(f"Loaded Mistral tokenizer for model: {model_id}") return wrapped_tokenizer - def _find_mistral_tokenizer_file(self, model_id: str) -> str: - """Find the appropriate tokenizer file for the given model.""" - # Generate all possible SentencePiece suffixes based on mistral_common patterns - instruct_versions = ["v1", "v2", "v3", "v7"] - mm_versions = ["m1", ""] # multimodal versions + empty string - - # Create all possible .model file suffixes - sentencepiece_suffixes = [ - f".model.{v}{m}" for v in instruct_versions for m in mm_versions - ] + [".model"] - - # List of tokenizer files to try, in order of preference - candidate_files = ["tekken.json"] # Try Tekken first - - # Add SentencePiece candidates in preference order (newer versions first) - preferred_sp_files = [ - "tokenizer.model.v7", # Latest instruction version - "tokenizer.model.v7m1", # Latest with multimodal - "tokenizer.model.v3", # Common version - "tokenizer.model.v3m1", # v3 with multimodal - "tokenizer.model.v2", # Older version - "tokenizer.model.v2m1", # v2 with multimodal - "tokenizer.model.v1", # Oldest versioned - "tokenizer.model.v1m1", # v1 with multimodal - "tokenizer.model", # Generic fallback - ] - candidate_files.extend(preferred_sp_files) - - # Try each candidate file - for filename in candidate_files: - try: - tokenizer_file = hf_hub_download(repo_id=model_id, filename=filename) - LOG.debug(f"Found tokenizer file: {filename}") - return tokenizer_file - except Exception: - continue - - # If no standard files found, try to list and find any matching files - try: - from huggingface_hub import list_repo_files - - repo_files = list_repo_files(repo_id=model_id) - - # Look for any files matching the SentencePiece patterns - matching_files = [] - for repo_file in repo_files: - if any(repo_file.endswith(suffix) for suffix in sentencepiece_suffixes): - matching_files.append(repo_file) - - if matching_files: - # Sort by preference (newer versions first) - def sort_key(filename): - # Prioritize by version number (v7 > v3 > v2 > v1) - if "v7" in filename: - return 0 - elif "v3" in filename: - return 1 - elif "v2" in filename: - return 2 - elif "v1" in filename: - return 3 - else: - return 4 - - matching_files.sort(key=sort_key) - tokenizer_file = hf_hub_download( - repo_id=model_id, filename=matching_files[0] - ) - LOG.debug(f"Using discovered tokenizer file: {matching_files[0]}") - return tokenizer_file - - except Exception as e: - LOG.debug(f"Could not list repo files: {e}") - - raise FileNotFoundError( - f"Could not find suitable tokenizer file for {model_id}. " - f"Tried: {', '.join(candidate_files[:5])}... and {len(candidate_files)-5} others" - ) - def get_tokenizer_class(self): """Get the appropriate tokenizer class.""" if self.cfg.tokenizer_type: @@ -610,17 +650,17 @@ def load_tokenizer(cfg): Fully configured tokenizer instance. """ # Configure tokenizer parameters - tokenizer_config = TokenizerConfiguration(cfg) + config = TokenizerConfiguration(cfg) # Check if we should use Mistral tokenizer - if tokenizer_config.should_use_mistral_tokenizer(): - tokenizer = tokenizer_config.load_mistral_tokenizer() + if config.detect_by_model_name_mapping(): + tokenizer = config.load_mistral_tokenizer() else: # Standard tokenizer loading - tokenizer_cls = tokenizer_config.get_tokenizer_class() - tokenizer_path = tokenizer_config.get_tokenizer_path() - use_fast = tokenizer_config.should_use_fast_tokenizer() - tokenizer_kwargs = tokenizer_config.get_tokenizer_kwargs() + tokenizer_cls = config.get_tokenizer_class() + tokenizer_path = config.get_tokenizer_path() + use_fast = config.should_use_fast_tokenizer() + tokenizer_kwargs = config.get_tokenizer_kwargs() # Initialize the tokenizer tokenizer = tokenizer_cls.from_pretrained(