feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612)

This commit is contained in:
thad0ctor
2026-04-22 21:26:34 -07:00
committed by GitHub
parent bcbe049c21
commit 1bf65c500e
4 changed files with 134 additions and 0 deletions

View File

@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.processor_kwargs:
processor_kwargs.update(cfg.processor_kwargs)
if cfg.tokenizer_use_mistral_common:

View File

@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
processor_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
},
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
)
return trust_remote_code
@field_validator("processor_kwargs")
@classmethod
def reject_reserved_processor_kwargs(cls, processor_kwargs):
if not processor_kwargs:
return processor_kwargs
reserved = {"revision", "trust_remote_code"}
conflicts = reserved.intersection(processor_kwargs)
if conflicts:
raise ValueError(
"Do not set reserved keys "
f"{sorted(conflicts)} inside `processor_kwargs`; "
"use the top-level `revision_of_model` / `trust_remote_code` "
"config keys instead."
)
return processor_kwargs
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""

View File

@@ -578,6 +578,11 @@ class TrainingValidationMixin:
"Setting chat_template is not supported with mistral-common tokenizer"
)
if data.get("processor_kwargs"):
raise ValueError(
"processor_kwargs is not supported with mistral-common tokenizer"
)
return data
@model_validator(mode="before")

View File

@@ -133,3 +133,108 @@ class TestRevisionParameter:
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
"processor_kwargs": {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
},
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert call_kwargs.kwargs.get("image_seq_length") == 1120
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_processor_kwargs_when_unset(
self, mock_auto_processor
):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "image_seq_length" not in call_kwargs.kwargs
assert "max_soft_tokens" not in call_kwargs.kwargs
def test_processor_kwargs_schema_rejects_revision(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="revision"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"revision": "abc123"},
)
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="trust_remote_code"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"trust_remote_code": True},
)
def test_processor_kwargs_schema_accepts_valid_keys(self):
from axolotl.utils.schemas.model import ModelInputConfig
cfg = ModelInputConfig(
base_model="some-model",
processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
)
assert cfg.processor_kwargs == {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
}
def test_processor_kwargs_schema_accepts_none_and_empty(self):
from axolotl.utils.schemas.model import ModelInputConfig
assert ModelInputConfig(base_model="x").processor_kwargs is None
assert (
ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
)
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
cfg = min_base_cfg | DictDefault(
tokenizer_use_mistral_common=True,
processor_kwargs={"image_seq_length": 1120},
)
with pytest.raises(ValueError, match="processor_kwargs"):
validate_config(cfg)