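Small changes: read the model path directly from `cfg.base_model` in `TokenizerConfiguration`, load the Mistral tokenizer with `MistralTokenizer.from_file`, and rewrap the `load_model_and_tokenizer` docstring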

This commit is contained in:
Dan Saunders
2025-06-05 22:36:46 +00:00
parent 13ddb8f172
commit 65f8988efd
2 changed files with 9 additions and 25 deletions

View File

@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.

View File

@@ -367,15 +367,8 @@ class TokenizerConfiguration:
self.model_config = load_model_config(cfg)
def detect_by_model_name_mapping(self) -> bool:
model_path = getattr(self.cfg, "model_name_or_path", "") or getattr(
self.cfg, "base_model", ""
)
if not model_path:
return False
# Extract model name from path
model = model_path.split("/")[-1]
model = self.cfg.base_model.split("/")[-1]
for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
if model_name in model.lower():
return True
@@ -384,24 +377,15 @@ class TokenizerConfiguration:
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
"""Load Mistral tokenizer from model configuration."""
model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
self.cfg, "base_model", None
)
if not model_id:
raise ValueError(
"model_name_or_path or base_model must be specified for Mistral tokenizer"
)
# First try to use the model name mapping for direct instantiation
model_name = model_id.split("/")[-1] # Extract model name from path
tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[model_name]
mistral_tokenizer = tokenizer_factory()
# Instantiate Mistral tokenizer
model_id = self.cfg.base_model
mistral_tokenizer = MistralTokenizer.from_file(model_id)
# Wrap it for compatibility
wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
return wrapped_tokenizer
return tokenizer
def get_tokenizer_class(self):
"""Get the appropriate tokenizer class."""