feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612)

This commit is contained in:
thad0ctor
2026-04-22 21:26:34 -07:00
committed by GitHub
parent bcbe049c21
commit 1bf65c500e
4 changed files with 134 additions and 0 deletions

View File

@@ -133,3 +133,108 @@ class TestRevisionParameter:
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
"processor_kwargs": {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
},
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert call_kwargs.kwargs.get("image_seq_length") == 1120
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_processor_kwargs_when_unset(
self, mock_auto_processor
):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "image_seq_length" not in call_kwargs.kwargs
assert "max_soft_tokens" not in call_kwargs.kwargs
def test_processor_kwargs_schema_rejects_revision(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="revision"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"revision": "abc123"},
)
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="trust_remote_code"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"trust_remote_code": True},
)
def test_processor_kwargs_schema_accepts_valid_keys(self):
from axolotl.utils.schemas.model import ModelInputConfig
cfg = ModelInputConfig(
base_model="some-model",
processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
)
assert cfg.processor_kwargs == {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
}
def test_processor_kwargs_schema_accepts_none_and_empty(self):
from axolotl.utils.schemas.model import ModelInputConfig
assert ModelInputConfig(base_model="x").processor_kwargs is None
assert (
ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
)
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
cfg = min_base_cfg | DictDefault(
tokenizer_use_mistral_common=True,
processor_kwargs={"image_seq_length": 1120},
)
with pytest.raises(ValueError, match="processor_kwargs"):
validate_config(cfg)