small changes
This commit is contained in:
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
|
|||||||
ProcessorMixin | None,
|
ProcessorMixin | None,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
Helper function for loading a model, tokenizer, and processor specified in the
|
||||||
config.
|
given `axolotl` config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|||||||
@@ -367,15 +367,8 @@ class TokenizerConfiguration:
|
|||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
|
|
||||||
def detect_by_model_name_mapping(self) -> bool:
|
def detect_by_model_name_mapping(self) -> bool:
|
||||||
model_path = getattr(self.cfg, "model_name_or_path", "") or getattr(
|
|
||||||
self.cfg, "base_model", ""
|
|
||||||
)
|
|
||||||
if not model_path:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Extract model name from path
|
# Extract model name from path
|
||||||
model = model_path.split("/")[-1]
|
model = self.cfg.base_model.split("/")[-1]
|
||||||
|
|
||||||
for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
|
for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
|
||||||
if model_name in model.lower():
|
if model_name in model.lower():
|
||||||
return True
|
return True
|
||||||
@@ -384,24 +377,15 @@ class TokenizerConfiguration:
|
|||||||
|
|
||||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
||||||
"""Load Mistral tokenizer from model configuration."""
|
"""Load Mistral tokenizer from model configuration."""
|
||||||
model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
|
# Instantiate Mistral tokenizer
|
||||||
self.cfg, "base_model", None
|
model_id = self.cfg.base_model
|
||||||
)
|
mistral_tokenizer = MistralTokenizer.from_file(model_id)
|
||||||
if not model_id:
|
|
||||||
raise ValueError(
|
|
||||||
"model_name_or_path or base_model must be specified for Mistral tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
# First try to use the model name mapping for direct instantiation
|
|
||||||
model_name = model_id.split("/")[-1] # Extract model name from path
|
|
||||||
tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[model_name]
|
|
||||||
mistral_tokenizer = tokenizer_factory()
|
|
||||||
|
|
||||||
# Wrap it for compatibility
|
# Wrap it for compatibility
|
||||||
wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||||
|
|
||||||
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
||||||
return wrapped_tokenizer
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
def get_tokenizer_class(self):
|
def get_tokenizer_class(self):
|
||||||
"""Get the appropriate tokenizer class."""
|
"""Get the appropriate tokenizer class."""
|
||||||
|
|||||||
Reference in New Issue
Block a user