diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py
index 211c26060..4cb03b84d 100644
--- a/src/axolotl/loaders/processor.py
+++ b/src/axolotl/loaders/processor.py
@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     processor_kwargs = {}
     if cfg.revision_of_model:
         processor_kwargs["revision"] = cfg.revision_of_model
+    if cfg.processor_kwargs:
+        processor_kwargs.update(cfg.processor_kwargs)
 
     if cfg.tokenizer_use_mistral_common:
diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py
index 3c5dfc6e3..f54958b33 100644
--- a/src/axolotl/utils/schemas/model.py
+++ b/src/axolotl/utils/schemas/model.py
@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
     processor_type: str | None = Field(
         default=None, json_schema_extra={"description": "transformers processor class"}
     )
+    processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
+        },
+    )
     tokenizer_save_jinja_files: bool | None = Field(
         default=True,  # match the default behavior from transformers
         json_schema_extra={
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
             )
         return trust_remote_code
 
+    @field_validator("processor_kwargs")
+    @classmethod
+    def reject_reserved_processor_kwargs(cls, processor_kwargs):
+        if not processor_kwargs:
+            return processor_kwargs
+        reserved = {"revision", "trust_remote_code"}
+        conflicts = reserved.intersection(processor_kwargs)
+        if conflicts:
+            raise ValueError(
+                "Do not set reserved keys "
+                f"{sorted(conflicts)} inside `processor_kwargs`; "
+                "use the top-level `revision_of_model` / `trust_remote_code` "
+                "config keys instead."
+            )
+        return processor_kwargs
+
 
 class ModelOutputConfig(BaseModel):
     """model save configuration subset"""
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 9161765d0..fff69de26 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -578,6 +578,11 @@ class TrainingValidationMixin:
                 "Setting chat_template is not supported with mistral-common tokenizer"
             )
 
+        if data.get("processor_kwargs"):
+            raise ValueError(
+                "processor_kwargs is not supported with mistral-common tokenizer"
+            )
+
         return data
 
     @model_validator(mode="before")
diff --git a/tests/test_revision_parameter.py b/tests/test_revision_parameter.py
index 2116b223f..112badfb8 100644
--- a/tests/test_revision_parameter.py
+++ b/tests/test_revision_parameter.py
@@ -133,3 +133,108 @@ class TestRevisionParameter:
 
         call_kwargs = mock_auto_processor.from_pretrained.call_args
         assert "revision" not in call_kwargs.kwargs
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+                "processor_kwargs": {
+                    "image_seq_length": 1120,
+                    "max_soft_tokens": 1120,
+                },
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("image_seq_length") == 1120
+        assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_omits_processor_kwargs_when_unset(
+        self, mock_auto_processor
+    ):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert "image_seq_length" not in call_kwargs.kwargs
+        assert "max_soft_tokens" not in call_kwargs.kwargs
+
+    def test_processor_kwargs_schema_rejects_revision(self):
+        import pytest
+
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        with pytest.raises(ValueError, match="revision"):
+            ModelInputConfig(
+                base_model="some-model",
+                processor_kwargs={"revision": "abc123"},
+            )
+
+    def test_processor_kwargs_schema_rejects_trust_remote_code(self):
+        import pytest
+
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        with pytest.raises(ValueError, match="trust_remote_code"):
+            ModelInputConfig(
+                base_model="some-model",
+                processor_kwargs={"trust_remote_code": True},
+            )
+
+    def test_processor_kwargs_schema_accepts_valid_keys(self):
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        cfg = ModelInputConfig(
+            base_model="some-model",
+            processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
+        )
+        assert cfg.processor_kwargs == {
+            "image_seq_length": 1120,
+            "max_soft_tokens": 1120,
+        }
+
+    def test_processor_kwargs_schema_accepts_none_and_empty(self):
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        assert ModelInputConfig(base_model="x").processor_kwargs is None
+        assert (
+            ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
+        )
+
+    def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
+        import pytest
+
+        from axolotl.utils.config import validate_config
+        from axolotl.utils.dict import DictDefault
+
+        cfg = min_base_cfg | DictDefault(
+            tokenizer_use_mistral_common=True,
+            processor_kwargs={"image_seq_length": 1120},
+        )
+        with pytest.raises(ValueError, match="processor_kwargs"):
+            validate_config(cfg)
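
Usage illustration (not part of the patch): with this change, a user config can pass processor options straight through to `AutoProcessor.from_pretrained()`. A minimal YAML sketch, assuming a placeholder model name; `image_seq_length` and `min_pixels` are example keys taken from the schema description above, and the values shown are illustrative, not prescribed by this patch:

    base_model: some-vision-model  # placeholder model id
    processor_kwargs:
      image_seq_length: 1120       # example processor option
      min_pixels: 3136             # example value, not prescribed here

Note that `revision` and `trust_remote_code` are rejected inside `processor_kwargs` by the new validator; use the top-level `revision_of_model` / `trust_remote_code` keys instead, and that `processor_kwargs` cannot be combined with `tokenizer_use_mistral_common`.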