feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries` (#3625)
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` / `cfg.train_on_eos` in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config. Also adds `cfg.role_boundaries` YAML override so users can declare per-role markers without subclassing. What changed ------------ - `ProcessingStrategy` gains a declarative boundary scanner. Each strategy declares per-role start/end markers via `_build_role_boundaries`; the shared scanner honors `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last"). - New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken. - Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner, now shared). - Strategies whose boundary tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via `cfg.role_boundaries`. - Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]` token between user-end and assistant-start via `include_end=False` + scanner rewind. See `docs/multimodal_assistant_mask.md` for the full audit table, root-cause analysis, and design rationale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` / `cfg.train_on_eos` in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config. Also adds `cfg.role_boundaries` YAML override so users can declare per-role markers without subclassing. What changed ------------ - `ProcessingStrategy` gains a declarative boundary scanner. Each strategy declares per-role start/end markers via `_build_role_boundaries`; the shared scanner honors `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last"). - New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken. - Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner, now shared). - Strategies whose boundary tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via `cfg.role_boundaries`. - Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]` token between user-end and assistant-start via `include_end=False` + scanner rewind. See `docs/multimodal_assistant_mask.md` for the full audit table, root-cause analysis, and design rationale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs+types: address CodeRabbit nitpicks on PR #7 - builders/causal.py: add inline NOTE that multi-dataset configs reuse the first dataset's masking knobs (roles_to_train / train_on_eos) for all datasets — heterogeneous per-dataset overrides are not supported in the MM path today. - processing_strategies.py: annotate inner scanner helpers _match_prefix and _find_end with explicit types (Tensor, int, list[int] → bool / tuple[int, bool]) for readability. - docs/multimodal_assistant_mask.md: renumber the "Commits on this branch" list to 1-7 consecutive (previously skipped 3). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(mm-mask): address two CodeRabbit findings on PR #7 1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it. `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML users hit a pydantic ValidationError at config load. Added "none" to the Literal and updated the description. 2. `cfg.role_boundaries: []` had split-personality semantics: the strategy ctor treated it as "replace built-ins with empty" while the collator plumbing treated it as "unset", and both the design doc and the MultiModalConfig schema help text promised wholesale replacement for any set value. Aligned on opt-in semantics across all four surfaces — a non-empty list replaces built-ins wholesale; unset or `[]` falls back to built-ins. Rationale: honoring `[]` literally yields all-masked labels and zero gradient, which is almost always a typo or leftover rather than a deliberate user action. Users who want to disable role masking should unset the field or use `train_on_inputs: true`. Also sharpened the fallback one-shot warning for strategies without built-in boundaries: names the consequence ("only pad and media tokens are masked, every other token contributes to loss") and points users at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead of "see axolotl/processing_strategies.py for how to declare boundaries." Files: - src/axolotl/utils/schemas/datasets.py: Literal adds "none" - src/axolotl/processing_strategies.py: ctor truthiness check on role_boundaries_override; sharpened fallback warning - src/axolotl/utils/schemas/multimodal.py: role_boundaries description now calls out opt-in + empty-list fallback semantics - docs/multimodal_assistant_mask.md: same clarification in the Semantics block; updated the fallback-path detection paragraph to quote the new warning text - tests/test_processing_strategies.py: +2 regressions (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values, test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * doc cleanup * fix(mm-mask): CodeRabbit findings + lint fix on PR #3625 Pre-commit failure: trailing newline missing on docs/multimodal_assistant_mask.md (end-of-file-fixer hook). Six CodeRabbit findings addressed: 1. Scanner: non-trainable role's end marker ignored ``include_end``. Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end with ``include_end=False``, intentionally re-matched as assistant-start) leaked into loss via the user branch on Pixtral / Mistral V7 Tekken. Fix: gate the non-trainable branch on ``best_match.include_end`` to mirror the trainable branch. 2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``, which never fires on real checkpoints (``special_tokens_map`` only holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct attribute read ``getattr(tokenizer, "boi_token", None)``, matching what ``transformers.models.gemma3.processing_gemma3`` itself does. Updated the ``_gemma_tokenizer`` test fixture to mirror real-model shape so the test exercises the production code path. 3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V / GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell through to the base fallback. Both processors ship identical media-token markers, so register both under the shared ``Glm4vProcessingStrategy`` with independent try/except import blocks. Updated class docstring. +2 dispatcher regressions. 4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token. Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")`` with unk-id guard; fall back to 262144 only if the string isn't in vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern. 5. ``build_collator`` was called twice per ``build()`` (eval + train passes), producing two identical ``MM collator: ...`` INFO banners on startup. Gate the log on ``is_eval=False`` so only the training pass emits it. 6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0, always returned ``None``; the dispatcher already handles missing ``mistral_common`` via lazy import + ``try/except``). Added ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false`` — a focused scanner-level lock-in for finding #1, independent of any specific VLM strategy. Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings - Scanner perf: convert labels[i] to a Python list once per row so _match_prefix / _find_end operate on list slices instead of re-materializing Tensor slices via .tolist() on every probe. Cuts O(n*boundaries) CPython↔C boundary crossings per batch. - Markdown lint (MD001, MD040): promote two h3 section headings to h2 under the h1; add `text` language to the verify-at-runtime fenced block. - Shorten verbose comments/docstrings added in recent commits to bare-minimum "why" notes matching the repo's existing style. 68/68 tests, 8/8 pre-commit hooks still pass.
This commit is contained in:
84
docs/multimodal_assistant_mask.md
Normal file
84
docs/multimodal_assistant_mask.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# Multimodal assistant-only loss masking
|
||||
|
||||
## Correct placement
|
||||
|
||||
```yaml
|
||||
# Top-level: only train_on_inputs lives here.
|
||||
train_on_inputs: false
|
||||
|
||||
datasets:
|
||||
- path: data/train.jsonl
|
||||
type: chat_template
|
||||
roles_to_train: # per-dataset — this is what the MM scanner reads
|
||||
- assistant
|
||||
train_on_eos: turn # per-dataset — same
|
||||
|
||||
test_datasets:
|
||||
- path: data/val.jsonl
|
||||
type: chat_template
|
||||
split: train
|
||||
roles_to_train:
|
||||
- assistant
|
||||
train_on_eos: turn
|
||||
```
|
||||
|
||||
## How to verify at runtime
|
||||
|
||||
`build_collator` logs the resolved knobs at INFO:
|
||||
|
||||
```text
|
||||
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
|
||||
```
|
||||
|
||||
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
|
||||
scanner — check that they are under `datasets[0]`, not at the root.
|
||||
|
||||
Each verified strategy additionally logs its resolved boundary token ids at
|
||||
strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for
|
||||
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
|
||||
pad and media tokens are masked" one-shot warning instead, it is on the
|
||||
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
|
||||
(below) to activate masking. The strategies currently on this path are
|
||||
listed in the audit table above under `fallback + warn`.
|
||||
|
||||
## Config-based override: `cfg.role_boundaries`
|
||||
|
||||
For the "unverified" strategies above, or for custom chat templates that
|
||||
don't match a built-in strategy's markers, users can declare role boundaries
|
||||
directly in YAML without subclassing:
|
||||
|
||||
```yaml
|
||||
role_boundaries:
|
||||
- role: assistant
|
||||
start: "<|turn>model"
|
||||
end: "<turn|>"
|
||||
- role: user
|
||||
start: "<|turn>user"
|
||||
end: "<turn|>"
|
||||
# Optional keys:
|
||||
# include_start: false # default False
|
||||
# include_end: true # default True, respects cfg.train_on_eos
|
||||
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
|
||||
# end: null # span runs to end of sequence
|
||||
```
|
||||
|
||||
Semantics:
|
||||
|
||||
- `start` and `end` are literal strings; axolotl encodes them at strategy
|
||||
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
|
||||
resolved token-id sequences at INFO level.
|
||||
- The special value `end: eos_token` is the portable way to express
|
||||
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
|
||||
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
|
||||
the strategy's built-in declarations wholesale (partial overlays are
|
||||
intentionally unsupported — they're hard to reason about at review time).
|
||||
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
|
||||
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
|
||||
always a typo or leftover — honoring it literally would produce all-masked
|
||||
labels and zero gradient, so it is treated the same as unset.
|
||||
- `cfg.roles_to_train` still governs which declared roles contribute to
|
||||
loss. You can declare `user` and `assistant` boundaries and set
|
||||
`roles_to_train: ["assistant"]` to have the scanner correctly identify
|
||||
user spans as masking boundaries without training on their content.
|
||||
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
|
||||
unencodable markers), not silently at loss-compute time.
|
||||
@@ -515,12 +515,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
if self.cfg.processor_type and self.processor:
|
||||
collator = MultiModalChatDataCollator
|
||||
# Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
|
||||
# NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets;
|
||||
# heterogeneous per-dataset overrides are not supported in the MM path today.
|
||||
ds_entries = self.cfg.datasets or []
|
||||
ds_cfg = ds_entries[0] if ds_entries else None
|
||||
|
||||
def _ds_get(cfg_obj, key):
|
||||
# Handle DictDefault / dict / pydantic uniformly:
|
||||
# dict-style .get first, then attribute access.
|
||||
if cfg_obj is None:
|
||||
return None
|
||||
if hasattr(cfg_obj, "get"):
|
||||
try:
|
||||
return cfg_obj.get(key)
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
pass
|
||||
return getattr(cfg_obj, key, None)
|
||||
|
||||
roles_to_train = _ds_get(ds_cfg, "roles_to_train")
|
||||
train_on_eos = _ds_get(ds_cfg, "train_on_eos")
|
||||
|
||||
# cfg.role_boundaries replaces the strategy's built-in markers.
|
||||
role_boundaries_override = None
|
||||
if self.cfg.role_boundaries:
|
||||
role_boundaries_override = list(self.cfg.role_boundaries)
|
||||
|
||||
# build() calls build_collator twice (eval + train); log once.
|
||||
if not is_eval:
|
||||
LOG.info(
|
||||
"MM collator: train_on_inputs=%s roles_to_train=%s "
|
||||
"train_on_eos=%s role_boundaries_override=%s",
|
||||
bool(self.cfg.train_on_inputs),
|
||||
roles_to_train,
|
||||
train_on_eos,
|
||||
"set" if role_boundaries_override else "none",
|
||||
)
|
||||
|
||||
kwargs["processing_strategy"] = get_processing_strategy(
|
||||
self.processor,
|
||||
training_args.chat_template,
|
||||
self.cfg.chat_template,
|
||||
image_size=training_args.image_size,
|
||||
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||
train_on_inputs=bool(self.cfg.train_on_inputs),
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
role_boundaries_override=role_boundaries_override,
|
||||
)
|
||||
elif self.cfg.batch_flattening:
|
||||
collator = DataCollatorWithFlattening
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -166,10 +166,10 @@ class SFTDataset(BaseModel):
|
||||
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
|
||||
},
|
||||
)
|
||||
train_on_eos: Literal["all", "turn", "last"] | None = Field(
|
||||
train_on_eos: Literal["all", "turn", "last", "none"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
|
||||
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens"
|
||||
},
|
||||
)
|
||||
roles: dict[str, list[str]] | None = Field(
|
||||
|
||||
@@ -6,6 +6,57 @@ from PIL.Image import Resampling
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class RoleBoundarySpec(BaseModel):
|
||||
"""One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md."""
|
||||
|
||||
role: str = Field(
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Role name as it appears in cfg.roles_to_train (e.g. "
|
||||
"'assistant', 'user', 'system', 'tool', 'ipython')."
|
||||
)
|
||||
},
|
||||
)
|
||||
start: str = Field(
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Literal string that marks the start of this role's span in "
|
||||
"the rendered chat template. Tokenized via "
|
||||
"``tokenizer.encode(..., add_special_tokens=False)`` at "
|
||||
"strategy init."
|
||||
)
|
||||
},
|
||||
)
|
||||
end: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Literal string that marks the end of this role's span. "
|
||||
"Set to ``eos_token`` to terminate at the tokenizer's EOS. "
|
||||
"Leave unset / null to terminate at end-of-sequence."
|
||||
)
|
||||
},
|
||||
)
|
||||
include_start: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Whether the start marker tokens contribute to loss on "
|
||||
"trainable turns. Default False."
|
||||
)
|
||||
},
|
||||
)
|
||||
include_end: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Whether the end marker tokens contribute to loss on "
|
||||
"trainable turns (honoring cfg.train_on_eos). Default True."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MultiModalConfig(BaseModel):
|
||||
"""Multi-modal configuration subset"""
|
||||
|
||||
@@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel):
|
||||
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
|
||||
},
|
||||
)
|
||||
role_boundaries: list[RoleBoundarySpec] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Opt-in override for the MM mask scanner's per-role boundary "
|
||||
"markers. Non-empty list replaces built-ins wholesale; unset "
|
||||
"or empty falls back to built-ins. See "
|
||||
"docs/multimodal_assistant_mask.md."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("image_resize_algorithm", mode="before")
|
||||
@classmethod
|
||||
|
||||
1164
tests/test_processing_strategies.py
Normal file
1164
tests/test_processing_strategies.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user