fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders * fix: address revision=None passed to .from_pretrained * add tests and address review feedback for revision parameter - Reformat modify_tokenizer_files signature and from_pretrained call - Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision - Add 6 unit tests for revision parameter in tokenizer/processor loaders --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
if cfg.processor_type:
|
||||
processor_cls = getattr(transformers, cfg.processor_type)
|
||||
|
||||
# Build common kwargs for processor loading
|
||||
processor_kwargs = {}
|
||||
if cfg.revision_of_model:
|
||||
processor_kwargs["revision"] = cfg.revision_of_model
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
|
||||
def _patch_mistralcommontokenizer():
|
||||
@@ -40,6 +45,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
if processor_cls == VoxtralProcessor:
|
||||
return VoxtralProcessor.from_pretrained(
|
||||
cfg.processor_config,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
from axolotl.utils.mistral import Mistral3Processor
|
||||
@@ -48,10 +54,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
|
||||
processor_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
processor = processor_cls.from_pretrained(
|
||||
cfg.processor_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
tokenizer=tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
# Attempt to load image size from processor if available
|
||||
|
||||
@@ -28,7 +28,10 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
|
||||
def modify_tokenizer_files(
|
||||
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
|
||||
tokenizer_path: str,
|
||||
token_mappings: dict[int, str],
|
||||
output_dir: str,
|
||||
revision: str = "main",
|
||||
) -> str:
|
||||
"""
|
||||
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
||||
@@ -41,6 +44,7 @@ def modify_tokenizer_files(
|
||||
tokenizer_path: Path or name of the original tokenizer
|
||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||
output_dir: Directory to save the modified tokenizer
|
||||
revision: Model revision/branch/tag/commit to load from (HF Hub)
|
||||
|
||||
Returns:
|
||||
Path to the modified tokenizer directory
|
||||
@@ -53,7 +57,9 @@ def modify_tokenizer_files(
|
||||
|
||||
if is_local_main_process():
|
||||
# Load the tokenizer
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, use_fast=True, revision=revision
|
||||
)
|
||||
|
||||
# Save the tokenizer to the output directory
|
||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||
@@ -134,7 +140,10 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
kwargs = {}
|
||||
if cfg.revision_of_model:
|
||||
kwargs["revision"] = cfg.revision_of_model
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -150,6 +159,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
if cfg.revision_of_model:
|
||||
tokenizer_kwargs["revision"] = cfg.revision_of_model
|
||||
|
||||
tokenizer_cls = AutoTokenizer
|
||||
if cfg.tokenizer_type:
|
||||
@@ -161,8 +172,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
# Apply token string overrides if specified
|
||||
if cfg.added_tokens_overrides:
|
||||
# Modify tokenizer files and get path to modified tokenizer
|
||||
modify_kwargs = {"output_dir": cfg.output_dir}
|
||||
if cfg.revision_of_model:
|
||||
modify_kwargs["revision"] = cfg.revision_of_model
|
||||
tokenizer_path = modify_tokenizer_files(
|
||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user