feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612)
This commit is contained in:
@@ -133,3 +133,108 @@ class TestRevisionParameter:
|
||||
|
||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||
assert "revision" not in call_kwargs.kwargs
|
||||
|
||||
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.size = {}
|
||||
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"processor_config": "some-model",
|
||||
"trust_remote_code": False,
|
||||
"processor_kwargs": {
|
||||
"image_seq_length": 1120,
|
||||
"max_soft_tokens": 1120,
|
||||
},
|
||||
}
|
||||
)
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||
|
||||
from axolotl.loaders.processor import load_processor
|
||||
|
||||
load_processor(cfg, tokenizer)
|
||||
|
||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||
assert call_kwargs.kwargs.get("image_seq_length") == 1120
|
||||
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
|
||||
|
||||
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||
def test_load_processor_omits_processor_kwargs_when_unset(
|
||||
self, mock_auto_processor
|
||||
):
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.size = {}
|
||||
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"processor_config": "some-model",
|
||||
"trust_remote_code": False,
|
||||
}
|
||||
)
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||
|
||||
from axolotl.loaders.processor import load_processor
|
||||
|
||||
load_processor(cfg, tokenizer)
|
||||
|
||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||
assert "image_seq_length" not in call_kwargs.kwargs
|
||||
assert "max_soft_tokens" not in call_kwargs.kwargs
|
||||
|
||||
def test_processor_kwargs_schema_rejects_revision(self):
    """`revision` is reserved and must be rejected inside processor_kwargs."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    forbidden = {"revision": "abc123"}
    with pytest.raises(ValueError, match="revision"):
        ModelInputConfig(base_model="some-model", processor_kwargs=forbidden)
|
||||
|
||||
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
    """`trust_remote_code` is reserved and must be rejected inside processor_kwargs."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    forbidden = {"trust_remote_code": True}
    with pytest.raises(ValueError, match="trust_remote_code"):
        ModelInputConfig(base_model="some-model", processor_kwargs=forbidden)
|
||||
|
||||
def test_processor_kwargs_schema_accepts_valid_keys(self):
    """Non-reserved keys pass validation and round-trip unchanged."""
    from axolotl.utils.schemas.model import ModelInputConfig

    cfg = ModelInputConfig(
        base_model="some-model",
        processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
    )
    # The validated field must contain exactly the keys that were supplied.
    assert cfg.processor_kwargs == {"image_seq_length": 1120, "max_soft_tokens": 1120}
|
||||
|
||||
def test_processor_kwargs_schema_accepts_none_and_empty(self):
    """Omitted processor_kwargs defaults to None; an explicit {} stays empty."""
    from axolotl.utils.schemas.model import ModelInputConfig

    omitted = ModelInputConfig(base_model="x")
    assert omitted.processor_kwargs is None

    explicit_empty = ModelInputConfig(base_model="x", processor_kwargs={})
    assert explicit_empty.processor_kwargs == {}
|
||||
|
||||
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
    """processor_kwargs must be refused when the mistral-common tokenizer is enabled."""
    import pytest

    from axolotl.utils.config import validate_config
    from axolotl.utils.dict import DictDefault

    override = DictDefault(
        tokenizer_use_mistral_common=True,
        processor_kwargs={"image_seq_length": 1120},
    )
    merged = min_base_cfg | override
    with pytest.raises(ValueError, match="processor_kwargs"):
        validate_config(merged)
|
||||
|
||||
Reference in New Issue
Block a user