feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612)
This commit is contained in:
@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
if cfg.revision_of_model:
|
if cfg.revision_of_model:
|
||||||
processor_kwargs["revision"] = cfg.revision_of_model
|
processor_kwargs["revision"] = cfg.revision_of_model
|
||||||
|
if cfg.processor_kwargs:
|
||||||
|
processor_kwargs.update(cfg.processor_kwargs)
|
||||||
|
|
||||||
if cfg.tokenizer_use_mistral_common:
|
if cfg.tokenizer_use_mistral_common:
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
|
|||||||
processor_type: str | None = Field(
|
processor_type: str | None = Field(
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
|
processor_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
|
||||||
|
},
|
||||||
|
)
|
||||||
tokenizer_save_jinja_files: bool | None = Field(
|
tokenizer_save_jinja_files: bool | None = Field(
|
||||||
default=True, # match the default behavior from transformers
|
default=True, # match the default behavior from transformers
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return trust_remote_code
|
return trust_remote_code
|
||||||
|
|
||||||
|
@field_validator("processor_kwargs")
@classmethod
def reject_reserved_processor_kwargs(cls, processor_kwargs):
    """Refuse `processor_kwargs` that contain a reserved key.

    The keys `revision` and `trust_remote_code` are rejected; the error
    message directs users to the top-level `revision_of_model` /
    `trust_remote_code` config keys instead. A falsy value (``None`` or an
    empty dict) is returned untouched.

    Raises:
        ValueError: if any reserved key is present in `processor_kwargs`.
    """
    if processor_kwargs:
        # Sorted so the error message lists the offending keys deterministically.
        clashing = sorted(
            key
            for key in ("revision", "trust_remote_code")
            if key in processor_kwargs
        )
        if clashing:
            raise ValueError(
                "Do not set reserved keys "
                f"{clashing} inside `processor_kwargs`; "
                "use the top-level `revision_of_model` / `trust_remote_code` "
                "config keys instead."
            )
    return processor_kwargs
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
class ModelOutputConfig(BaseModel):
|
||||||
"""model save configuration subset"""
|
"""model save configuration subset"""
|
||||||
|
|||||||
@@ -578,6 +578,11 @@ class TrainingValidationMixin:
|
|||||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data.get("processor_kwargs"):
|
||||||
|
raise ValueError(
|
||||||
|
"processor_kwargs is not supported with mistral-common tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -133,3 +133,108 @@ class TestRevisionParameter:
|
|||||||
|
|
||||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
assert "revision" not in call_kwargs.kwargs
|
assert "revision" not in call_kwargs.kwargs
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
    """Every entry of `cfg.processor_kwargs` is passed to `from_pretrained` as a kwarg."""
    from axolotl.loaders.processor import load_processor

    forwarded = {
        "image_seq_length": 1120,
        "max_soft_tokens": 1120,
    }

    # Stand-in processor returned by the patched AutoProcessor.
    fake_processor = MagicMock()
    fake_processor.size = {}
    mock_auto_processor.from_pretrained.return_value = fake_processor

    cfg = DictDefault(
        {
            "processor_config": "some-model",
            "trust_remote_code": False,
            "processor_kwargs": dict(forwarded),
        }
    )
    tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

    load_processor(cfg, tokenizer)

    passed_kwargs = mock_auto_processor.from_pretrained.call_args.kwargs
    for key, expected in forwarded.items():
        assert passed_kwargs.get(key) == expected
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_processor_kwargs_when_unset(
    self, mock_auto_processor
):
    """Without `processor_kwargs` in the config, no such keys reach `from_pretrained`."""
    from axolotl.loaders.processor import load_processor

    # Stand-in processor returned by the patched AutoProcessor.
    fake_processor = MagicMock()
    fake_processor.size = {}
    mock_auto_processor.from_pretrained.return_value = fake_processor

    cfg = DictDefault(
        {
            "processor_config": "some-model",
            "trust_remote_code": False,
        }
    )
    tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

    load_processor(cfg, tokenizer)

    passed_kwargs = mock_auto_processor.from_pretrained.call_args.kwargs
    for absent_key in ("image_seq_length", "max_soft_tokens"):
        assert absent_key not in passed_kwargs
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_rejects_revision(self):
    """`revision` inside `processor_kwargs` makes schema construction raise."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    reserved_payload = {"revision": "abc123"}

    with pytest.raises(ValueError, match="revision"):
        ModelInputConfig(base_model="some-model", processor_kwargs=reserved_payload)
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
    """`trust_remote_code` inside `processor_kwargs` makes schema construction raise."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    reserved_payload = {"trust_remote_code": True}

    with pytest.raises(ValueError, match="trust_remote_code"):
        ModelInputConfig(base_model="some-model", processor_kwargs=reserved_payload)
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_accepts_valid_keys(self):
    """Non-reserved keys are accepted and stored on the model unchanged."""
    from axolotl.utils.schemas.model import ModelInputConfig

    expected = {
        "image_seq_length": 1120,
        "max_soft_tokens": 1120,
    }

    # Pass a copy so the comparison below is against an independent dict.
    model_cfg = ModelInputConfig(
        base_model="some-model",
        processor_kwargs=dict(expected),
    )

    assert model_cfg.processor_kwargs == expected
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_accepts_none_and_empty(self):
    """An omitted field yields ``None``; an explicit empty dict stays an empty dict."""
    from axolotl.utils.schemas.model import ModelInputConfig

    unset_cfg = ModelInputConfig(base_model="x")
    assert unset_cfg.processor_kwargs is None

    empty_cfg = ModelInputConfig(base_model="x", processor_kwargs={})
    assert empty_cfg.processor_kwargs == {}
|
||||||
|
|
||||||
|
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
    """Setting `processor_kwargs` together with `tokenizer_use_mistral_common` fails validation."""
    import pytest

    from axolotl.utils.config import validate_config
    from axolotl.utils.dict import DictDefault

    overrides = DictDefault(
        tokenizer_use_mistral_common=True,
        processor_kwargs={"image_seq_length": 1120},
    )
    cfg = min_base_cfg | overrides

    with pytest.raises(ValueError, match="processor_kwargs"):
        validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user