fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders
* fix: address revision=None being passed to .from_pretrained
* add tests and address review feedback for the revision parameter
  - Reformat the modify_tokenizer_files signature and from_pretrained call
  - Use a kwargs pattern for the modify_tokenizer_files call to avoid passing revision=None
  - Add 6 unit tests for the revision parameter in the tokenizer/processor loaders

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
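The "kwargs pattern" referenced above builds the keyword arguments for .from_pretrained conditionally, so that revision is omitted entirely (rather than passed as None) whenever cfg.revision_of_model is unset. A minimal standalone sketch of the idea, assuming only transformers and a DictDefault-like cfg (the function name here is illustrative, not from the diff):

    from transformers import AutoTokenizer

    def load_with_optional_revision(cfg):
        # Build kwargs conditionally: unpacking an empty dict leaves
        # from_pretrained to fall back to its own default revision ("main").
        tokenizer_kwargs = {}
        if cfg.revision_of_model:
            tokenizer_kwargs["revision"] = cfg.revision_of_model
        return AutoTokenizer.from_pretrained(cfg.tokenizer_config, **tokenizer_kwargs)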
src/axolotl/loaders/processor.py

@@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     if cfg.processor_type:
         processor_cls = getattr(transformers, cfg.processor_type)
 
+        # Build common kwargs for processor loading
+        processor_kwargs = {}
+        if cfg.revision_of_model:
+            processor_kwargs["revision"] = cfg.revision_of_model
+
         if cfg.tokenizer_use_mistral_common:
 
             def _patch_mistralcommontokenizer():
@@ -40,6 +45,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
             if processor_cls == VoxtralProcessor:
                 return VoxtralProcessor.from_pretrained(
                     cfg.processor_config,
+                    **processor_kwargs,
                 )
 
             from axolotl.utils.mistral import Mistral3Processor
@@ -48,10 +54,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
                 tokenizer=tokenizer,
             )
 
+        processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
+        processor_kwargs["tokenizer"] = tokenizer
+
         processor = processor_cls.from_pretrained(
             cfg.processor_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            tokenizer=tokenizer,
+            **processor_kwargs,
         )
 
     # Attempt to load image size from processor if available
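For illustration, a usage sketch of the processor path after this change; the repo id and revision below are placeholders, and a real multimodal tokenizer would normally stand in for the mock:

    from unittest.mock import MagicMock

    from transformers import PreTrainedTokenizerBase

    from axolotl.loaders.processor import load_processor
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "processor_config": "org/some-multimodal-model",  # placeholder repo id
            "revision_of_model": "abc123",  # placeholder branch/tag/commit
            "trust_remote_code": False,
        }
    )
    tokenizer = MagicMock(spec=PreTrainedTokenizerBase)  # stand-in tokenizer
    # Processor files are now resolved at revision "abc123" instead of "main".
    processor = load_processor(cfg, tokenizer)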
src/axolotl/loaders/tokenizer.py

@@ -28,7 +28,10 @@ PLUGIN_MANAGER = PluginManager.get_instance()
 
 
 def modify_tokenizer_files(
-    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
+    tokenizer_path: str,
+    token_mappings: dict[int, str],
+    output_dir: str,
+    revision: str = "main",
 ) -> str:
     """
     Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -41,6 +44,7 @@ def modify_tokenizer_files(
         tokenizer_path: Path or name of the original tokenizer
         token_mappings: Dict mapping {token_id (int): new_token_string}
         output_dir: Directory to save the modified tokenizer
+        revision: Model revision/branch/tag/commit to load from (HF Hub)
 
     Returns:
         Path to the modified tokenizer directory
@@ -53,7 +57,9 @@ def modify_tokenizer_files(
 
     if is_local_main_process():
         # Load the tokenizer
-        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+        temp_tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_path, use_fast=True, revision=revision
+        )
 
         # Save the tokenizer to the output directory
         temp_tokenizer.save_pretrained(tokenizer_dir)
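Given the updated signature and docstring, a hedged example call; the repo id, token mapping, and revision are made-up placeholders:

    from axolotl.loaders.tokenizer import modify_tokenizer_files

    # Rewrite the string for one added token while sourcing the tokenizer
    # files from a pinned Hub revision (defaults to "main" when omitted).
    tokenizer_path = modify_tokenizer_files(
        "org/some-model",  # placeholder tokenizer repo id
        {128002: "<|custom_token|>"},  # token_id -> new token string
        output_dir="./outputs",
        revision="abc123",  # placeholder revision
    )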
@@ -134,7 +140,10 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
         from axolotl.utils.mistral import HFMistralTokenizer
 
         # Load the HF-compatible wrapper around MistralTokenizer
-        tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
+        kwargs = {}
+        if cfg.revision_of_model:
+            kwargs["revision"] = cfg.revision_of_model
+        tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
 
         return tokenizer
 
@@ -150,6 +159,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     if cfg.tokenizer_legacy is not None:
         # True is the default w/ https://github.com/huggingface/transformers/pull/25224
         tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
+    if cfg.revision_of_model:
+        tokenizer_kwargs["revision"] = cfg.revision_of_model
 
     tokenizer_cls = AutoTokenizer
     if cfg.tokenizer_type:
@@ -161,8 +172,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     # Apply token string overrides if specified
     if cfg.added_tokens_overrides:
         # Modify tokenizer files and get path to modified tokenizer
+        modify_kwargs = {"output_dir": cfg.output_dir}
+        if cfg.revision_of_model:
+            modify_kwargs["revision"] = cfg.revision_of_model
         tokenizer_path = modify_tokenizer_files(
-            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+            tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
         )
 
     tokenizer = tokenizer_cls.from_pretrained(
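End to end, the tokenizer load can now be pinned the same way; a sketch mirroring the unit tests added below (repo id and revision are placeholders):

    from axolotl.loaders.tokenizer import load_tokenizer
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "tokenizer_config": "org/some-model",  # placeholder repo id
            "revision_of_model": "abc123",  # placeholder revision
        }
    )
    # AutoTokenizer.from_pretrained receives revision="abc123"; when
    # revision_of_model is unset, no revision kwarg is passed at all.
    tokenizer = load_tokenizer(cfg)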
tests/test_revision_parameter.py (new file, 135 lines)

@@ -0,0 +1,135 @@
+"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
+
+from unittest.mock import MagicMock, patch
+
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.dict import DictDefault
+
+
+class TestRevisionParameter:
+    """Tests for revision_of_model being passed to tokenizer and processor loaders."""
+
+    @patch("axolotl.loaders.tokenizer.load_model_config")
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch(
+        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
+    )
+    def test_load_tokenizer_passes_revision(
+        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.__class__.__name__ = "MockTokenizer"
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "some-model",
+                "revision_of_model": "abc123",
+            }
+        )
+        from axolotl.loaders.tokenizer import load_tokenizer
+
+        load_tokenizer(cfg)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.tokenizer.load_model_config")
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch(
+        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
+    )
+    def test_load_tokenizer_omits_revision_when_unset(
+        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.__class__.__name__ = "MockTokenizer"
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "some-model",
+            }
+        )
+        from axolotl.loaders.tokenizer import load_tokenizer
+
+        load_tokenizer(cfg)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert "revision" not in call_kwargs.kwargs
+
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
+    @patch("axolotl.loaders.tokenizer.barrier")
+    def test_modify_tokenizer_files_passes_revision(
+        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
+    ):
+        mock_tokenizer = MagicMock()
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        from axolotl.loaders.tokenizer import modify_tokenizer_files
+
+        modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
+    @patch("axolotl.loaders.tokenizer.barrier")
+    def test_modify_tokenizer_files_defaults_revision_to_main(
+        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
+    ):
+        mock_tokenizer = MagicMock()
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        from axolotl.loaders.tokenizer import modify_tokenizer_files
+
+        modify_tokenizer_files("some-model", {}, output_dir=temp_dir)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "main"
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_passes_revision(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "revision_of_model": "abc123",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert "revision" not in call_kwargs.kwargs
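The six tests above use mocks throughout, so no Hub access is needed; they should run with a plain pytest invocation such as pytest tests/test_revision_parameter.py -q (temp_dir is assumed to be a fixture supplied by the repository's test conftest).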