From b1570ed0fa71c4da3b5491fe0d6703bde2b5ef82 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 29 May 2025 20:04:35 +0000
Subject: [PATCH] update

---
 src/axolotl/loaders/tokenizer.py | 116 ++++++++++++++++++++++++-------
 tests/test_tokenizers.py         |  58 +++++++++++++---
 2 files changed, 138 insertions(+), 36 deletions(-)

diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index 08dcaac03..925db9bd2 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -35,7 +35,7 @@ LLAMA_TOKENIZER_CLASSES = {
     "CodeLlamaTokenizerFast",
 }
 FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
-MISTRAL_MODEL_TYPES = {"mistral", "mistral3"}
+MISTRAL_MODEL_TYPES = {"mistral", "mistral3", "mixtral"}
 
 QWEN_DEFAULT_TOKEN = "<|endoftext|>"  # nosec B105
 GPTNEOX_PAD_TOKEN = "[PAD]"  # nosec B105
@@ -269,35 +269,99 @@ class TokenizerConfiguration:
 
     def should_use_mistral_tokenizer(self) -> bool:
         """Determine if Mistral tokenizer should be used."""
-        # Explicit configuration
         return self.model_config.model_type in MISTRAL_MODEL_TYPES
 
     def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
         """Load Mistral tokenizer from model configuration."""
-        model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
-            self.cfg, "base_model", None
-        )
-        if not model_id:
-            raise ValueError(
-                "model_name_or_path or base_model must be specified for Mistral tokenizer"
-            )
+        # Try to find the appropriate tokenizer file
+        model_id = self.cfg.base_model
+        tokenizer_file = self._find_mistral_tokenizer_file(model_id)
 
+        # Load the Mistral tokenizer and wrap for transformers compatibility
+        mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
+        wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
+
+        LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
+        return wrapped_tokenizer
+
+    def _find_mistral_tokenizer_file(self, model_id: str) -> str:
+        """Find the appropriate tokenizer file for the given model."""
+        # Generate all possible SentencePiece suffixes based on mistral_common patterns
+        instruct_versions = ["v1", "v2", "v3", "v7"]
+        mm_versions = ["m1", ""]  # multimodal versions + empty string
+
+        # Create all possible .model file suffixes
+        sentencepiece_suffixes = [
+            f".model.{v}{m}" for v in instruct_versions for m in mm_versions
+        ] + [".model"]
+
+        # List of tokenizer files to try, in order of preference
+        candidate_files = ["tekken.json"]  # Try Tekken first
+
+        # Add SentencePiece candidates in preference order (newer versions first)
+        preferred_sp_files = [
+            "tokenizer.model.v7",  # Latest instruction version
+            "tokenizer.model.v7m1",  # Latest with multimodal
+            "tokenizer.model.v3",  # Common version
+            "tokenizer.model.v3m1",  # v3 with multimodal
+            "tokenizer.model.v2",  # Older version
+            "tokenizer.model.v2m1",  # v2 with multimodal
+            "tokenizer.model.v1",  # Oldest versioned
+            "tokenizer.model.v1m1",  # v1 with multimodal
+            "tokenizer.model",  # Generic fallback
+        ]
+        candidate_files.extend(preferred_sp_files)
+
+        # Try each candidate file
+        for filename in candidate_files:
+            try:
+                tokenizer_file = hf_hub_download(repo_id=model_id, filename=filename)
+                LOG.debug(f"Found tokenizer file: {filename}")
+                return tokenizer_file
+            except Exception:
+                continue
+
+        # If no standard files found, try to list and find any matching files
         try:
-            # Download the tekken.json file for the tokenizer
-            tekken_file = hf_hub_download(repo_id=model_id, filename="tekken.json")
+            from huggingface_hub import list_repo_files
 
-            # Load the Mistral tokenizer
-            mistral_tokenizer = MistralTokenizer.from_file(tekken_file)
+            repo_files = list_repo_files(repo_id=model_id)
 
-            # Wrap it for compatibility
-            wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
+            # Look for any files matching the SentencePiece patterns
+            matching_files = []
+            for repo_file in repo_files:
+                if any(repo_file.endswith(suffix) for suffix in sentencepiece_suffixes):
+                    matching_files.append(repo_file)
 
-            LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
-            return wrapped_tokenizer
+            if matching_files:
+                # Sort by preference (newer versions first)
+                def sort_key(filename):
+                    # Prioritize by version number (v7 > v3 > v2 > v1)
+                    if "v7" in filename:
+                        return 0
+                    elif "v3" in filename:
+                        return 1
+                    elif "v2" in filename:
+                        return 2
+                    elif "v1" in filename:
+                        return 3
+                    else:
+                        return 4
+
+                matching_files.sort(key=sort_key)
+                tokenizer_file = hf_hub_download(
+                    repo_id=model_id, filename=matching_files[0]
+                )
+                LOG.debug(f"Using discovered tokenizer file: {matching_files[0]}")
+                return tokenizer_file
         except Exception as e:
-            LOG.error(f"Failed to load Mistral tokenizer: {e}")
-            raise
+            LOG.debug(f"Could not list repo files: {e}")
+
+        raise FileNotFoundError(
+            f"Could not find suitable tokenizer file for {model_id}. "
+            f"Tried: {', '.join(candidate_files[:5])}... and {len(candidate_files) - 5} others"
+        )
 
     def get_tokenizer_class(self):
         """Get the appropriate tokenizer class."""
@@ -546,17 +610,17 @@ def load_tokenizer(cfg):
         Fully configured tokenizer instance.
     """
     # Configure tokenizer parameters
-    config = TokenizerConfiguration(cfg)
+    tokenizer_config = TokenizerConfiguration(cfg)
 
     # Check if we should use Mistral tokenizer
-    if config.should_use_mistral_tokenizer():
-        tokenizer = config.load_mistral_tokenizer()
+    if tokenizer_config.should_use_mistral_tokenizer():
+        tokenizer = tokenizer_config.load_mistral_tokenizer()
     else:
         # Standard tokenizer loading
-        tokenizer_cls = config.get_tokenizer_class()
-        tokenizer_path = config.get_tokenizer_path()
-        use_fast = config.should_use_fast_tokenizer()
-        tokenizer_kwargs = config.get_tokenizer_kwargs()
+        tokenizer_cls = tokenizer_config.get_tokenizer_class()
+        tokenizer_path = tokenizer_config.get_tokenizer_path()
+        use_fast = tokenizer_config.should_use_fast_tokenizer()
+        tokenizer_kwargs = tokenizer_config.get_tokenizer_kwargs()
 
         # Initialize the tokenizer
         tokenizer = tokenizer_cls.from_pretrained(
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index 406462038..104006c1b 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -1,8 +1,4 @@
-"""
-Test cases for the tokenizer loading
-"""
-
-import unittest
+"""Test cases for tokenizer loading."""
 
 import pytest
 
@@ -13,9 +9,7 @@ from tests.hf_offline_utils import enable_hf_offline
 
 
 class TestTokenizers:
-    """
-    test class for the load_tokenizer fn
-    """
+    """Test class for the load_tokenizer fn"""
 
     @enable_hf_offline
     def test_default_use_fast(self):
@@ -155,6 +149,50 @@ class TestTokenizers:
         ):
             load_tokenizer(cfg)
 
+    def test_mistral_tokenizer_auto_detection(self):
+        """Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
 
-if __name__ == "__main__":
-    unittest.main()
+    def test_mixtral_tokenizer_auto_detection(self):
+        """Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "model-hub/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
+
+    def test_mistral_tokenizer_basic_functionality(self):
+        """Test basic encode/decode functionality of MistralTokenizerWrapper"""
+        cfg = DictDefault(
+            {
+                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+
+        # Test basic encoding
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+        # Test basic decoding
+        decoded = tokenizer.decode(tokens)
+        assert isinstance(decoded, str)
+
+        # Test token properties are accessible
+        assert hasattr(tokenizer, "eos_token_id")
+        assert hasattr(tokenizer, "bos_token_id")
+        assert isinstance(tokenizer.eos_token_id, int)
+        assert isinstance(tokenizer.bos_token_id, int)