diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py
index 124dad39e..e07e324d6 100644
--- a/src/axolotl/loaders/processor.py
+++ b/src/axolotl/loaders/processor.py
@@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     if cfg.processor_type:
         processor_cls = getattr(transformers, cfg.processor_type)
 
+    # Build common kwargs for processor loading
+    processor_kwargs = {}
+    if cfg.revision_of_model:
+        processor_kwargs["revision"] = cfg.revision_of_model
+
     if cfg.tokenizer_use_mistral_common:
 
         def _patch_mistralcommontokenizer():
@@ -40,6 +45,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
         if processor_cls == VoxtralProcessor:
             return VoxtralProcessor.from_pretrained(
                 cfg.processor_config,
+                **processor_kwargs,
             )
 
         from axolotl.utils.mistral import Mistral3Processor
@@ -48,10 +54,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
             tokenizer=tokenizer,
         )
 
+    processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
+    processor_kwargs["tokenizer"] = tokenizer
+
     processor = processor_cls.from_pretrained(
         cfg.processor_config,
-        trust_remote_code=cfg.trust_remote_code or False,
-        tokenizer=tokenizer,
+        **processor_kwargs,
     )
 
     # Attempt to load image size from processor if available
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index 170ebf333..d45d23bae 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -28,7 +28,10 @@ PLUGIN_MANAGER = PluginManager.get_instance()
 
 
 def modify_tokenizer_files(
-    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
+    tokenizer_path: str,
+    token_mappings: dict[int, str],
+    output_dir: str,
+    revision: str = "main",
 ) -> str:
     """
     Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -41,6 +44,7 @@ def modify_tokenizer_files(
         tokenizer_path: Path or name of the original tokenizer
         token_mappings: Dict mapping {token_id (int): new_token_string}
         output_dir: Directory to save the modified tokenizer
+        revision: Model revision/branch/tag/commit to load from (HF Hub)
 
     Returns:
         Path to the modified tokenizer directory
@@ -53,7 +57,9 @@ def modify_tokenizer_files(
 
     if is_local_main_process():
         # Load the tokenizer
-        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+        temp_tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_path, use_fast=True, revision=revision
+        )
 
         # Save the tokenizer to the output directory
         temp_tokenizer.save_pretrained(tokenizer_dir)
@@ -134,7 +140,10 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
         from axolotl.utils.mistral import HFMistralTokenizer
 
         # Load the HF-compatible wrapper around MistralTokenizer
-        tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
+        kwargs = {}
+        if cfg.revision_of_model:
+            kwargs["revision"] = cfg.revision_of_model
+        tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
 
         return tokenizer
 
@@ -150,6 +159,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     if cfg.tokenizer_legacy is not None:
         # True is the default w/ https://github.com/huggingface/transformers/pull/25224
         tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
+    if cfg.revision_of_model:
+        tokenizer_kwargs["revision"] = cfg.revision_of_model
 
     tokenizer_cls = AutoTokenizer
     if cfg.tokenizer_type:
@@ -161,8 +172,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     # Apply token string overrides if specified
     if cfg.added_tokens_overrides:
         # Modify tokenizer files and get path to modified tokenizer
+        modify_kwargs = {"output_dir": cfg.output_dir}
+        if cfg.revision_of_model:
+            modify_kwargs["revision"] = cfg.revision_of_model
         tokenizer_path = modify_tokenizer_files(
-            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+            tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
         )
 
     tokenizer = tokenizer_cls.from_pretrained(
diff --git a/tests/test_revision_parameter.py b/tests/test_revision_parameter.py
new file mode 100644
index 000000000..2116b223f
--- /dev/null
+++ b/tests/test_revision_parameter.py
@@ -0,0 +1,135 @@
+"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
+
+from unittest.mock import MagicMock, patch
+
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.dict import DictDefault
+
+
+class TestRevisionParameter:
+    """Tests for revision_of_model being passed to tokenizer and processor loaders."""
+
+    @patch("axolotl.loaders.tokenizer.load_model_config")
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch(
+        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
+    )
+    def test_load_tokenizer_passes_revision(
+        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.__class__.__name__ = "MockTokenizer"
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "some-model",
+                "revision_of_model": "abc123",
+            }
+        )
+        from axolotl.loaders.tokenizer import load_tokenizer
+
+        load_tokenizer(cfg)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.tokenizer.load_model_config")
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch(
+        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
+    )
+    def test_load_tokenizer_omits_revision_when_unset(
+        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.__class__.__name__ = "MockTokenizer"
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "some-model",
+            }
+        )
+        from axolotl.loaders.tokenizer import load_tokenizer
+
+        load_tokenizer(cfg)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert "revision" not in call_kwargs.kwargs
+
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
+    @patch("axolotl.loaders.tokenizer.barrier")
+    def test_modify_tokenizer_files_passes_revision(
+        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
+    ):
+        mock_tokenizer = MagicMock()
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        from axolotl.loaders.tokenizer import modify_tokenizer_files
+
+        modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
+    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
+    @patch("axolotl.loaders.tokenizer.barrier")
+    def test_modify_tokenizer_files_defaults_revision_to_main(
+        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
+    ):
+        mock_tokenizer = MagicMock()
+        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
+
+        from axolotl.loaders.tokenizer import modify_tokenizer_files
+
+        modify_tokenizer_files("some-model", {}, output_dir=temp_dir)
+
+        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "main"
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_passes_revision(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "revision_of_model": "abc123",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("revision") == "abc123"
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert "revision" not in call_kwargs.kwargs
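
For reference, a minimal sketch of how the new plumbing is exercised end to end. This is illustrative only: the repo name "org/some-model" and the SHA "abc123" are placeholders, the cfg is deliberately stripped down, and a real run supplies the full axolotl config (load_tokenizer, for example, also consults load_model_config internally).

```python
from axolotl.loaders.processor import load_processor
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.dict import DictDefault

# Hypothetical config; "org/some-model" and "abc123" are placeholders.
# revision_of_model accepts any branch, tag, or commit SHA on the HF Hub.
cfg = DictDefault(
    {
        "tokenizer_config": "org/some-model",
        "processor_config": "org/some-model",
        "revision_of_model": "abc123",
        "trust_remote_code": False,
    }
)

# Both loaders now forward revision_of_model as revision= to from_pretrained,
# pinning the tokenizer/processor files to the same snapshot as the model.
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer)
```

Building the kwargs dicts conditionally, rather than always passing revision=None, keeps the call signatures byte-for-byte identical to the old behavior when revision_of_model is unset, which is what the *_omits_revision_when_unset tests assert.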