From 65f8988efdcfee463890a75d4d001e03ad343315 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 5 Jun 2025 22:36:46 +0000
Subject: [PATCH] small changes

---
 src/axolotl/cli/utils.py         |  4 ++--
 src/axolotl/loaders/tokenizer.py | 30 +++++++-----------------------
 2 files changed, 9 insertions(+), 25 deletions(-)

diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index d28795361..c30ad2f73 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
     ProcessorMixin | None,
 ]:
     """
-    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
-    config.
+    Helper function for loading a model, tokenizer, and processor specified in the
+    given `axolotl` config.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index 5837cb73d..7c764941b 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -367,15 +367,8 @@ class TokenizerConfiguration:
         self.model_config = load_model_config(cfg)
 
     def detect_by_model_name_mapping(self) -> bool:
-        model_path = getattr(self.cfg, "model_name_or_path", "") or getattr(
-            self.cfg, "base_model", ""
-        )
-        if not model_path:
-            return False
-
         # Extract model name from path
-        model = model_path.split("/")[-1]
-
+        model = self.cfg.base_model.split("/")[-1]
         for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
             if model_name in model.lower():
                 return True
@@ -384,24 +377,15 @@ class TokenizerConfiguration:
 
     def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
         """Load Mistral tokenizer from model configuration."""
-        model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
-            self.cfg, "base_model", None
-        )
-        if not model_id:
-            raise ValueError(
-                "model_name_or_path or base_model must be specified for Mistral tokenizer"
-            )
-
-        # First try to use the model name mapping for direct instantiation
-        model_name = model_id.split("/")[-1]  # Extract model name from path
-        tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[model_name]
-        mistral_tokenizer = tokenizer_factory()
+        # Instantiate Mistral tokenizer
+        model_id = self.cfg.base_model
+        mistral_tokenizer = MistralTokenizer.from_file(model_id)
 
         # Wrap it for compatibility
-        wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
-
+        tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
         LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
-        return wrapped_tokenizer
+
+        return tokenizer
 
     def get_tokenizer_class(self):
         """Get the appropriate tokenizer class."""