241 lines
8.5 KiB
Python
241 lines
8.5 KiB
Python
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
class TestRevisionParameter:
    """Tests for revision_of_model being passed to tokenizer and processor loaders.

    Also covers forwarding of ``processor_kwargs`` to the processor loader and
    the schema/config validation applied to ``processor_kwargs``.
    """

    # ------------------------------------------------------------------
    # Private helpers (names do not start with ``test`` so pytest does not
    # collect them).
    # ------------------------------------------------------------------

    @staticmethod
    def _stub_tokenizer(mock_auto_tokenizer, class_name=None):
        """Make the patched ``AutoTokenizer.from_pretrained`` return a mock.

        Args:
            mock_auto_tokenizer: the patched ``AutoTokenizer`` object.
            class_name: when given, assigned to the returned mock's
                ``__class__.__name__``.
        """
        tokenizer = MagicMock()
        if class_name is not None:
            # Each MagicMock instance has its own class, so renaming it does
            # not leak into other mocks.
            tokenizer.__class__.__name__ = class_name
        mock_auto_tokenizer.from_pretrained.return_value = tokenizer

    @staticmethod
    def _stub_processor(mock_auto_processor):
        """Make the patched ``AutoProcessor.from_pretrained`` return a mock."""
        processor = MagicMock()
        # NOTE(review): presumably load_processor inspects ``processor.size``;
        # an empty dict keeps that a real mapping rather than a MagicMock.
        processor.size = {}
        mock_auto_processor.from_pretrained.return_value = processor

    @staticmethod
    def _from_pretrained_kwargs(auto_class_mock):
        """Return the keyword arguments of the last ``from_pretrained`` call.

        Asserts first that ``from_pretrained`` was called at all. Without this
        check ``call_args`` is ``None`` for an uncalled mock and the test would
        fail with an unhelpful ``AttributeError`` instead of an assertion.
        """
        auto_class_mock.from_pretrained.assert_called()
        return auto_class_mock.from_pretrained.call_args.kwargs

    # ------------------------------------------------------------------
    # load_tokenizer
    # ------------------------------------------------------------------

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_passes_revision(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """``revision_of_model`` reaches ``AutoTokenizer.from_pretrained`` as ``revision``."""
        self._stub_tokenizer(mock_auto_tokenizer, class_name="MockTokenizer")

        cfg = DictDefault(
            {
                "tokenizer_config": "some-model",
                "revision_of_model": "abc123",
            }
        )
        from axolotl.loaders.tokenizer import load_tokenizer

        load_tokenizer(cfg)

        kwargs = self._from_pretrained_kwargs(mock_auto_tokenizer)
        assert kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_omits_revision_when_unset(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """No ``revision`` keyword is passed when ``revision_of_model`` is absent."""
        self._stub_tokenizer(mock_auto_tokenizer, class_name="MockTokenizer")

        cfg = DictDefault(
            {
                "tokenizer_config": "some-model",
            }
        )
        from axolotl.loaders.tokenizer import load_tokenizer

        load_tokenizer(cfg)

        kwargs = self._from_pretrained_kwargs(mock_auto_tokenizer)
        assert "revision" not in kwargs

    # ------------------------------------------------------------------
    # modify_tokenizer_files
    # ------------------------------------------------------------------

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_passes_revision(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """An explicit ``revision`` argument is forwarded to ``from_pretrained``."""
        self._stub_tokenizer(mock_auto_tokenizer)

        from axolotl.loaders.tokenizer import modify_tokenizer_files

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")

        kwargs = self._from_pretrained_kwargs(mock_auto_tokenizer)
        assert kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_defaults_revision_to_main(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """Without a ``revision`` argument, ``from_pretrained`` receives ``"main"``."""
        self._stub_tokenizer(mock_auto_tokenizer)

        from axolotl.loaders.tokenizer import modify_tokenizer_files

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir)

        kwargs = self._from_pretrained_kwargs(mock_auto_tokenizer)
        assert kwargs.get("revision") == "main"

    # ------------------------------------------------------------------
    # load_processor
    # ------------------------------------------------------------------

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_passes_revision(self, mock_auto_processor):
        """``revision_of_model`` reaches ``AutoProcessor.from_pretrained`` as ``revision``."""
        self._stub_processor(mock_auto_processor)

        cfg = DictDefault(
            {
                "processor_config": "some-model",
                "revision_of_model": "abc123",
                "trust_remote_code": False,
            }
        )
        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

        from axolotl.loaders.processor import load_processor

        load_processor(cfg, tokenizer)

        kwargs = self._from_pretrained_kwargs(mock_auto_processor)
        assert kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
        """No ``revision`` keyword is passed when ``revision_of_model`` is absent."""
        self._stub_processor(mock_auto_processor)

        cfg = DictDefault(
            {
                "processor_config": "some-model",
                "trust_remote_code": False,
            }
        )
        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

        from axolotl.loaders.processor import load_processor

        load_processor(cfg, tokenizer)

        kwargs = self._from_pretrained_kwargs(mock_auto_processor)
        assert "revision" not in kwargs

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
        """Entries of ``processor_kwargs`` are passed as keywords to ``from_pretrained``."""
        self._stub_processor(mock_auto_processor)

        cfg = DictDefault(
            {
                "processor_config": "some-model",
                "trust_remote_code": False,
                "processor_kwargs": {
                    "image_seq_length": 1120,
                    "max_soft_tokens": 1120,
                },
            }
        )
        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

        from axolotl.loaders.processor import load_processor

        load_processor(cfg, tokenizer)

        kwargs = self._from_pretrained_kwargs(mock_auto_processor)
        assert kwargs.get("image_seq_length") == 1120
        assert kwargs.get("max_soft_tokens") == 1120

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_omits_processor_kwargs_when_unset(
        self, mock_auto_processor
    ):
        """No processor-specific keywords are passed when ``processor_kwargs`` is absent."""
        self._stub_processor(mock_auto_processor)

        cfg = DictDefault(
            {
                "processor_config": "some-model",
                "trust_remote_code": False,
            }
        )
        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

        from axolotl.loaders.processor import load_processor

        load_processor(cfg, tokenizer)

        kwargs = self._from_pretrained_kwargs(mock_auto_processor)
        assert "image_seq_length" not in kwargs
        assert "max_soft_tokens" not in kwargs

    # ------------------------------------------------------------------
    # processor_kwargs schema / config validation
    # ------------------------------------------------------------------

    def test_processor_kwargs_schema_rejects_revision(self):
        """``processor_kwargs`` containing ``revision`` is rejected by the schema."""
        import pytest

        from axolotl.utils.schemas.model import ModelInputConfig

        with pytest.raises(ValueError, match="revision"):
            ModelInputConfig(
                base_model="some-model",
                processor_kwargs={"revision": "abc123"},
            )

    def test_processor_kwargs_schema_rejects_trust_remote_code(self):
        """``processor_kwargs`` containing ``trust_remote_code`` is rejected by the schema."""
        import pytest

        from axolotl.utils.schemas.model import ModelInputConfig

        with pytest.raises(ValueError, match="trust_remote_code"):
            ModelInputConfig(
                base_model="some-model",
                processor_kwargs={"trust_remote_code": True},
            )

    def test_processor_kwargs_schema_accepts_valid_keys(self):
        """Other keys are accepted and stored unchanged on the config."""
        from axolotl.utils.schemas.model import ModelInputConfig

        cfg = ModelInputConfig(
            base_model="some-model",
            processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
        )
        assert cfg.processor_kwargs == {
            "image_seq_length": 1120,
            "max_soft_tokens": 1120,
        }

    def test_processor_kwargs_schema_accepts_none_and_empty(self):
        """``processor_kwargs`` defaults to ``None`` and an empty dict is kept as is."""
        from axolotl.utils.schemas.model import ModelInputConfig

        assert ModelInputConfig(base_model="x").processor_kwargs is None
        assert (
            ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
        )

    def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
        """``validate_config`` rejects ``processor_kwargs`` with ``tokenizer_use_mistral_common``."""
        import pytest

        from axolotl.utils.config import validate_config

        # DictDefault comes from the module-level import; no local re-import needed.
        cfg = min_base_cfg | DictDefault(
            tokenizer_use_mistral_common=True,
            processor_kwargs={"image_seq_length": 1120},
        )
        with pytest.raises(ValueError, match="processor_kwargs"):
            validate_config(cfg)
|