feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612)
@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     processor_kwargs = {}
     if cfg.revision_of_model:
         processor_kwargs["revision"] = cfg.revision_of_model
+    if cfg.processor_kwargs:
+        processor_kwargs.update(cfg.processor_kwargs)

     if cfg.tokenizer_use_mistral_common:
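For context, this is roughly what the forwarded kwargs end up doing downstream. The sketch below is illustrative rather than the exact load_processor body: it assumes the collected dict is splatted into AutoProcessor.from_pretrained, and image_seq_length is just an example of a processor-specific key a user might set under processor_kwargs in the YAML config.

from transformers import AutoProcessor

# Illustrative only: mirrors how the kwargs collected above would reach
# from_pretrained. `image_seq_length` stands in for any processor-specific
# option supplied via the hypothetical `processor_kwargs` YAML stanza.
processor_kwargs = {}
processor_kwargs["revision"] = "main"                  # from cfg.revision_of_model
processor_kwargs.update({"image_seq_length": 1120})    # from cfg.processor_kwargs

processor = AutoProcessor.from_pretrained(
    "some-model",
    trust_remote_code=False,
    **processor_kwargs,
)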
@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
     processor_type: str | None = Field(
         default=None, json_schema_extra={"description": "transformers processor class"}
     )
+    processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
+        },
+    )
     tokenizer_save_jinja_files: bool | None = Field(
         default=True,  # match the default behavior from transformers
         json_schema_extra={
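At the schema level the new field is a plain optional dict. A short sketch of the behavior the tests at the end of this diff pin down (key names illustrative):

from axolotl.utils.schemas.model import ModelInputConfig

# Omitted in the YAML -> defaults to None.
cfg = ModelInputConfig(base_model="some-model")
assert cfg.processor_kwargs is None

# Arbitrary processor-specific keys are accepted and stored as-is.
cfg = ModelInputConfig(
    base_model="some-model",
    processor_kwargs={"image_seq_length": 1120},
)
assert cfg.processor_kwargs == {"image_seq_length": 1120}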
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
         )
         return trust_remote_code

+    @field_validator("processor_kwargs")
+    @classmethod
+    def reject_reserved_processor_kwargs(cls, processor_kwargs):
+        if not processor_kwargs:
+            return processor_kwargs
+        reserved = {"revision", "trust_remote_code"}
+        conflicts = reserved.intersection(processor_kwargs)
+        if conflicts:
+            raise ValueError(
+                "Do not set reserved keys "
+                f"{sorted(conflicts)} inside `processor_kwargs`; "
+                "use the top-level `revision_of_model` / `trust_remote_code` "
+                "config keys instead."
+            )
+        return processor_kwargs
+

 class ModelOutputConfig(BaseModel):
     """model save configuration subset"""
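The validator fails fast when a reserved key is smuggled in through processor_kwargs. A sketch of what a user would see, assuming pydantic v2 semantics where a ValueError raised in a field_validator surfaces as a ValidationError at construction time:

from pydantic import ValidationError

from axolotl.utils.schemas.model import ModelInputConfig

try:
    ModelInputConfig(
        base_model="some-model",
        processor_kwargs={"revision": "abc123"},  # reserved key
    )
except ValidationError as err:
    # The error message points at the top-level `revision_of_model` /
    # `trust_remote_code` config keys instead.
    print(err)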
@@ -578,6 +578,11 @@ class TrainingValidationMixin:
                     "Setting chat_template is not supported with mistral-common tokenizer"
                 )

+            if data.get("processor_kwargs"):
+                raise ValueError(
+                    "processor_kwargs is not supported with mistral-common tokenizer"
+                )
+
         return data

     @model_validator(mode="before")
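And the config-level guard: processor_kwargs cannot be combined with the mistral-common tokenizer. A rough sketch of the rejected combination, assuming an otherwise-valid minimal config (the test at the end of this diff builds on a min_base_cfg fixture for exactly that reason):

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Hypothetical minimal config; a real config needs the usual required fields.
cfg = DictDefault(
    {
        "base_model": "some-model",
        "tokenizer_use_mistral_common": True,
        "processor_kwargs": {"image_seq_length": 1120},
    }
)

# Expected to raise: "processor_kwargs is not supported with mistral-common tokenizer"
validate_config(cfg)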
@@ -133,3 +133,108 @@ class TestRevisionParameter:

         call_kwargs = mock_auto_processor.from_pretrained.call_args
         assert "revision" not in call_kwargs.kwargs
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+                "processor_kwargs": {
+                    "image_seq_length": 1120,
+                    "max_soft_tokens": 1120,
+                },
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert call_kwargs.kwargs.get("image_seq_length") == 1120
+        assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
+
+    @patch("axolotl.loaders.processor.AutoProcessor")
+    def test_load_processor_omits_processor_kwargs_when_unset(
+        self, mock_auto_processor
+    ):
+        mock_processor = MagicMock()
+        mock_processor.size = {}
+        mock_auto_processor.from_pretrained.return_value = mock_processor
+
+        cfg = DictDefault(
+            {
+                "processor_config": "some-model",
+                "trust_remote_code": False,
+            }
+        )
+        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+
+        from axolotl.loaders.processor import load_processor
+
+        load_processor(cfg, tokenizer)
+
+        call_kwargs = mock_auto_processor.from_pretrained.call_args
+        assert "image_seq_length" not in call_kwargs.kwargs
+        assert "max_soft_tokens" not in call_kwargs.kwargs
+
+    def test_processor_kwargs_schema_rejects_revision(self):
+        import pytest
+
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        with pytest.raises(ValueError, match="revision"):
+            ModelInputConfig(
+                base_model="some-model",
+                processor_kwargs={"revision": "abc123"},
+            )
+
+    def test_processor_kwargs_schema_rejects_trust_remote_code(self):
+        import pytest
+
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        with pytest.raises(ValueError, match="trust_remote_code"):
+            ModelInputConfig(
+                base_model="some-model",
+                processor_kwargs={"trust_remote_code": True},
+            )
+
+    def test_processor_kwargs_schema_accepts_valid_keys(self):
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        cfg = ModelInputConfig(
+            base_model="some-model",
+            processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
+        )
+        assert cfg.processor_kwargs == {
+            "image_seq_length": 1120,
+            "max_soft_tokens": 1120,
+        }
+
+    def test_processor_kwargs_schema_accepts_none_and_empty(self):
+        from axolotl.utils.schemas.model import ModelInputConfig
+
+        assert ModelInputConfig(base_model="x").processor_kwargs is None
+        assert (
+            ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
+        )
+
+    def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
+        import pytest
+
+        from axolotl.utils.config import validate_config
+        from axolotl.utils.dict import DictDefault
+
+        cfg = min_base_cfg | DictDefault(
+            tokenizer_use_mistral_common=True,
+            processor_kwargs={"image_seq_length": 1120},
+        )
+        with pytest.raises(ValueError, match="processor_kwargs"):
+            validate_config(cfg)