fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders * fix: address revision=None passed to .from_pretrained * add tests and address review feedback for revision parameter - Reformat modify_tokenizer_files signature and from_pretrained call - Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision - Add 6 unit tests for revision parameter in tokenizer/processor loaders --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
if cfg.processor_type:
|
if cfg.processor_type:
|
||||||
processor_cls = getattr(transformers, cfg.processor_type)
|
processor_cls = getattr(transformers, cfg.processor_type)
|
||||||
|
|
||||||
|
# Build common kwargs for processor loading
|
||||||
|
processor_kwargs = {}
|
||||||
|
if cfg.revision_of_model:
|
||||||
|
processor_kwargs["revision"] = cfg.revision_of_model
|
||||||
|
|
||||||
if cfg.tokenizer_use_mistral_common:
|
if cfg.tokenizer_use_mistral_common:
|
||||||
|
|
||||||
def _patch_mistralcommontokenizer():
|
def _patch_mistralcommontokenizer():
|
||||||
@@ -40,6 +45,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
if processor_cls == VoxtralProcessor:
|
if processor_cls == VoxtralProcessor:
|
||||||
return VoxtralProcessor.from_pretrained(
|
return VoxtralProcessor.from_pretrained(
|
||||||
cfg.processor_config,
|
cfg.processor_config,
|
||||||
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.utils.mistral import Mistral3Processor
|
from axolotl.utils.mistral import Mistral3Processor
|
||||||
@@ -48,10 +54,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
|
||||||
|
processor_kwargs["tokenizer"] = tokenizer
|
||||||
|
|
||||||
processor = processor_cls.from_pretrained(
|
processor = processor_cls.from_pretrained(
|
||||||
cfg.processor_config,
|
cfg.processor_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
**processor_kwargs,
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attempt to load image size from processor if available
|
# Attempt to load image size from processor if available
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
|||||||
|
|
||||||
|
|
||||||
def modify_tokenizer_files(
|
def modify_tokenizer_files(
|
||||||
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
|
tokenizer_path: str,
|
||||||
|
token_mappings: dict[int, str],
|
||||||
|
output_dir: str,
|
||||||
|
revision: str = "main",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
||||||
@@ -41,6 +44,7 @@ def modify_tokenizer_files(
|
|||||||
tokenizer_path: Path or name of the original tokenizer
|
tokenizer_path: Path or name of the original tokenizer
|
||||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||||
output_dir: Directory to save the modified tokenizer
|
output_dir: Directory to save the modified tokenizer
|
||||||
|
revision: Model revision/branch/tag/commit to load from (HF Hub)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the modified tokenizer directory
|
Path to the modified tokenizer directory
|
||||||
@@ -53,7 +57,9 @@ def modify_tokenizer_files(
|
|||||||
|
|
||||||
if is_local_main_process():
|
if is_local_main_process():
|
||||||
# Load the tokenizer
|
# Load the tokenizer
|
||||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
temp_tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path, use_fast=True, revision=revision
|
||||||
|
)
|
||||||
|
|
||||||
# Save the tokenizer to the output directory
|
# Save the tokenizer to the output directory
|
||||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||||
@@ -134,7 +140,10 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
from axolotl.utils.mistral import HFMistralTokenizer
|
from axolotl.utils.mistral import HFMistralTokenizer
|
||||||
|
|
||||||
# Load the HF-compatible wrapper around MistralTokenizer
|
# Load the HF-compatible wrapper around MistralTokenizer
|
||||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
kwargs = {}
|
||||||
|
if cfg.revision_of_model:
|
||||||
|
kwargs["revision"] = cfg.revision_of_model
|
||||||
|
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@@ -150,6 +159,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
if cfg.tokenizer_legacy is not None:
|
if cfg.tokenizer_legacy is not None:
|
||||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
|
if cfg.revision_of_model:
|
||||||
|
tokenizer_kwargs["revision"] = cfg.revision_of_model
|
||||||
|
|
||||||
tokenizer_cls = AutoTokenizer
|
tokenizer_cls = AutoTokenizer
|
||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
@@ -161,8 +172,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
# Apply token string overrides if specified
|
# Apply token string overrides if specified
|
||||||
if cfg.added_tokens_overrides:
|
if cfg.added_tokens_overrides:
|
||||||
# Modify tokenizer files and get path to modified tokenizer
|
# Modify tokenizer files and get path to modified tokenizer
|
||||||
|
modify_kwargs = {"output_dir": cfg.output_dir}
|
||||||
|
if cfg.revision_of_model:
|
||||||
|
modify_kwargs["revision"] = cfg.revision_of_model
|
||||||
tokenizer_path = modify_tokenizer_files(
|
tokenizer_path = modify_tokenizer_files(
|
||||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
|
|||||||
135
tests/test_revision_parameter.py
Normal file
135
tests/test_revision_parameter.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevisionParameter:
|
||||||
|
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.tokenizer.load_model_config")
|
||||||
|
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
|
||||||
|
@patch(
|
||||||
|
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
|
||||||
|
)
|
||||||
|
def test_load_tokenizer_passes_revision(
|
||||||
|
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
|
||||||
|
):
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_tokenizer.__class__.__name__ = "MockTokenizer"
|
||||||
|
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "some-model",
|
||||||
|
"revision_of_model": "abc123",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
|
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
|
||||||
|
assert call_kwargs.kwargs.get("revision") == "abc123"
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.tokenizer.load_model_config")
|
||||||
|
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
|
||||||
|
@patch(
|
||||||
|
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
|
||||||
|
)
|
||||||
|
def test_load_tokenizer_omits_revision_when_unset(
|
||||||
|
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
|
||||||
|
):
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_tokenizer.__class__.__name__ = "MockTokenizer"
|
||||||
|
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "some-model",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
|
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
|
||||||
|
assert "revision" not in call_kwargs.kwargs
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
|
||||||
|
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
|
||||||
|
@patch("axolotl.loaders.tokenizer.barrier")
|
||||||
|
def test_modify_tokenizer_files_passes_revision(
|
||||||
|
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
|
||||||
|
):
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||||
|
|
||||||
|
from axolotl.loaders.tokenizer import modify_tokenizer_files
|
||||||
|
|
||||||
|
modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
|
||||||
|
assert call_kwargs.kwargs.get("revision") == "abc123"
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
|
||||||
|
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
|
||||||
|
@patch("axolotl.loaders.tokenizer.barrier")
|
||||||
|
def test_modify_tokenizer_files_defaults_revision_to_main(
|
||||||
|
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
|
||||||
|
):
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||||
|
|
||||||
|
from axolotl.loaders.tokenizer import modify_tokenizer_files
|
||||||
|
|
||||||
|
modify_tokenizer_files("some-model", {}, output_dir=temp_dir)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
|
||||||
|
assert call_kwargs.kwargs.get("revision") == "main"
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||||
|
def test_load_processor_passes_revision(self, mock_auto_processor):
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.size = {}
|
||||||
|
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"processor_config": "some-model",
|
||||||
|
"revision_of_model": "abc123",
|
||||||
|
"trust_remote_code": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||||
|
|
||||||
|
from axolotl.loaders.processor import load_processor
|
||||||
|
|
||||||
|
load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
|
assert call_kwargs.kwargs.get("revision") == "abc123"
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||||
|
def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.size = {}
|
||||||
|
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"processor_config": "some-model",
|
||||||
|
"trust_remote_code": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||||
|
|
||||||
|
from axolotl.loaders.processor import load_processor
|
||||||
|
|
||||||
|
load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
|
assert "revision" not in call_kwargs.kwargs
|
||||||
Reference in New Issue
Block a user