small changes
This commit is contained in:
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -367,15 +367,8 @@ class TokenizerConfiguration:
|
||||
self.model_config = load_model_config(cfg)
|
||||
|
||||
def detect_by_model_name_mapping(self) -> bool:
|
||||
model_path = getattr(self.cfg, "model_name_or_path", "") or getattr(
|
||||
self.cfg, "base_model", ""
|
||||
)
|
||||
if not model_path:
|
||||
return False
|
||||
|
||||
# Extract model name from path
|
||||
model = model_path.split("/")[-1]
|
||||
|
||||
model = self.cfg.base_model.split("/")[-1]
|
||||
for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
|
||||
if model_name in model.lower():
|
||||
return True
|
||||
@@ -384,24 +377,15 @@ class TokenizerConfiguration:
|
||||
|
||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
||||
"""Load Mistral tokenizer from model configuration."""
|
||||
model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
|
||||
self.cfg, "base_model", None
|
||||
)
|
||||
if not model_id:
|
||||
raise ValueError(
|
||||
"model_name_or_path or base_model must be specified for Mistral tokenizer"
|
||||
)
|
||||
|
||||
# First try to use the model name mapping for direct instantiation
|
||||
model_name = model_id.split("/")[-1] # Extract model name from path
|
||||
tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[model_name]
|
||||
mistral_tokenizer = tokenizer_factory()
|
||||
# Instantiate Mistral tokenizer
|
||||
model_id = self.cfg.base_model
|
||||
mistral_tokenizer = MistralTokenizer.from_file(model_id)
|
||||
|
||||
# Wrap it for compatibility
|
||||
wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||
|
||||
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
||||
return wrapped_tokenizer
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_tokenizer_class(self):
|
||||
"""Get the appropriate tokenizer class."""
|
||||
|
||||
Reference in New Issue
Block a user