fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders
* fix: address revision=None passed to .from_pretrained
* add tests and address review feedback for revision parameter
  - Reformat modify_tokenizer_files signature and from_pretrained call
  - Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision
  - Add 6 unit tests for revision parameter in tokenizer/processor loaders

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
135
tests/test_revision_parameter.py
Normal file
135
tests/test_revision_parameter.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestRevisionParameter:
    """Check how ``revision_of_model`` reaches the tokenizer/processor loaders.

    Every test replaces the Hugging Face ``Auto*`` class used by the loader
    with a mock, runs the loader, and then inspects the keyword arguments of
    the most recent ``from_pretrained`` call.
    """

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_passes_revision(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """``load_tokenizer`` hands a configured revision to ``from_pretrained``."""
        from axolotl.loaders.tokenizer import load_tokenizer

        fake_tokenizer = MagicMock()
        # Rename the mock's class so it reports a fixed class name.
        fake_tokenizer.__class__.__name__ = "MockTokenizer"
        mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

        config = DictDefault(
            {"tokenizer_config": "some-model", "revision_of_model": "abc123"}
        )
        load_tokenizer(config)

        last_call = mock_auto_tokenizer.from_pretrained.call_args
        assert last_call.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_omits_revision_when_unset(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """Without ``revision_of_model`` no ``revision`` kwarg is sent at all."""
        from axolotl.loaders.tokenizer import load_tokenizer

        fake_tokenizer = MagicMock()
        # Rename the mock's class so it reports a fixed class name.
        fake_tokenizer.__class__.__name__ = "MockTokenizer"
        mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

        config = DictDefault({"tokenizer_config": "some-model"})
        load_tokenizer(config)

        last_call = mock_auto_tokenizer.from_pretrained.call_args
        assert "revision" not in last_call.kwargs

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_passes_revision(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """An explicit ``revision`` argument is forwarded to ``from_pretrained``.

        ``is_local_main_process`` is forced to return ``True`` and ``barrier``
        is mocked out for the duration of the call.
        """
        from axolotl.loaders.tokenizer import modify_tokenizer_files

        mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")

        last_call = mock_auto_tokenizer.from_pretrained.call_args
        assert last_call.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_defaults_revision_to_main(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """Leaving ``revision`` out results in ``revision="main"`` being used.

        ``is_local_main_process`` is forced to return ``True`` and ``barrier``
        is mocked out for the duration of the call.
        """
        from axolotl.loaders.tokenizer import modify_tokenizer_files

        mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir)

        last_call = mock_auto_tokenizer.from_pretrained.call_args
        assert last_call.kwargs.get("revision") == "main"

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_passes_revision(self, mock_auto_processor):
        """``load_processor`` hands a configured revision to ``from_pretrained``."""
        from axolotl.loaders.processor import load_processor

        # The returned processor mock exposes an empty ``size`` mapping.
        mock_auto_processor.from_pretrained.return_value = MagicMock(size={})

        config = DictDefault(
            {
                "processor_config": "some-model",
                "revision_of_model": "abc123",
                "trust_remote_code": False,
            }
        )
        fake_tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
        load_processor(config, fake_tokenizer)

        last_call = mock_auto_processor.from_pretrained.call_args
        assert last_call.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
        """Without ``revision_of_model`` no ``revision`` kwarg is sent at all."""
        from axolotl.loaders.processor import load_processor

        # The returned processor mock exposes an empty ``size`` mapping.
        mock_auto_processor.from_pretrained.return_value = MagicMock(size={})

        config = DictDefault(
            {"processor_config": "some-model", "trust_remote_code": False}
        )
        fake_tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
        load_processor(config, fake_tokenizer)

        last_call = mock_auto_processor.from_pretrained.call_args
        assert "revision" not in last_call.kwargs
|
||||
Reference in New Issue
Block a user