feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries` (#3625)

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
This commit is contained in:
thad0ctor
2026-05-05 08:25:39 -07:00
committed by GitHub
parent c15f6cffe2
commit 5352d41d32
6 changed files with 2156 additions and 230 deletions

View File

@@ -0,0 +1,84 @@
# Multimodal assistant-only loss masking
## Correct placement
```yaml
# Top-level: only train_on_inputs lives here.
train_on_inputs: false
datasets:
- path: data/train.jsonl
type: chat_template
roles_to_train: # per-dataset — this is what the MM scanner reads
- assistant
train_on_eos: turn # per-dataset — same
test_datasets:
- path: data/val.jsonl
type: chat_template
split: train
roles_to_train:
- assistant
train_on_eos: turn
```
## How to verify at runtime
`build_collator` logs the resolved knobs at INFO:
```text
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
```
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
scanner — check that they are under `datasets[0]`, not at the root.
Each verified strategy additionally logs its resolved boundary token ids at
strategy init (e.g. `<|turn>model``[105, 4368]`, `<turn|>``[106]` for
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
pad and media tokens are masked" one-shot warning instead, it is on the
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
(below) to activate masking. The strategies currently on this path are
listed in the audit table above under `fallback + warn`.
## Config-based override: `cfg.role_boundaries`
For the "unverified" strategies above, or for custom chat templates that
don't match a built-in strategy's markers, users can declare role boundaries
directly in YAML without subclassing:
```yaml
role_boundaries:
- role: assistant
start: "<|turn>model"
end: "<turn|>"
- role: user
start: "<|turn>user"
end: "<turn|>"
# Optional keys:
# include_start: false # default False
# include_end: true # default True, respects cfg.train_on_eos
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
# end: null # span runs to end of sequence
```
Semantics:
- `start` and `end` are literal strings; axolotl encodes them at strategy
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
resolved token-id sequences at INFO level.
- The special value `end: eos_token` is the portable way to express
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
the strategy's built-in declarations wholesale (partial overlays are
intentionally unsupported — they're hard to reason about at review time).
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
always a typo or leftover — honoring it literally would produce all-masked
labels and zero gradient, so it is treated the same as unset.
- `cfg.roles_to_train` still governs which declared roles contribute to
loss. You can declare `user` and `assistant` boundaries and set
`roles_to_train: ["assistant"]` to have the scanner correctly identify
user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
unencodable markers), not silently at loss-compute time.

View File

@@ -515,12 +515,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
# Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
# NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets;
# heterogeneous per-dataset overrides are not supported in the MM path today.
ds_entries = self.cfg.datasets or []
ds_cfg = ds_entries[0] if ds_entries else None
def _ds_get(cfg_obj, key):
# Handle DictDefault / dict / pydantic uniformly:
# dict-style .get first, then attribute access.
if cfg_obj is None:
return None
if hasattr(cfg_obj, "get"):
try:
return cfg_obj.get(key)
except (AttributeError, KeyError, TypeError):
pass
return getattr(cfg_obj, key, None)
roles_to_train = _ds_get(ds_cfg, "roles_to_train")
train_on_eos = _ds_get(ds_cfg, "train_on_eos")
# cfg.role_boundaries replaces the strategy's built-in markers.
role_boundaries_override = None
if self.cfg.role_boundaries:
role_boundaries_override = list(self.cfg.role_boundaries)
# build() calls build_collator twice (eval + train); log once.
if not is_eval:
LOG.info(
"MM collator: train_on_inputs=%s roles_to_train=%s "
"train_on_eos=%s role_boundaries_override=%s",
bool(self.cfg.train_on_inputs),
roles_to_train,
train_on_eos,
"set" if role_boundaries_override else "none",
)
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
train_on_inputs=bool(self.cfg.train_on_inputs),
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
role_boundaries_override=role_boundaries_override,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening

File diff suppressed because it is too large Load Diff

View File

@@ -166,10 +166,10 @@ class SFTDataset(BaseModel):
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
},
)
train_on_eos: Literal["all", "turn", "last"] | None = Field(
train_on_eos: Literal["all", "turn", "last", "none"] | None = Field(
default=None,
json_schema_extra={
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens"
},
)
roles: dict[str, list[str]] | None = Field(

View File

@@ -6,6 +6,57 @@ from PIL.Image import Resampling
from pydantic import BaseModel, Field, field_validator
class RoleBoundarySpec(BaseModel):
"""One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md."""
role: str = Field(
json_schema_extra={
"description": (
"Role name as it appears in cfg.roles_to_train (e.g. "
"'assistant', 'user', 'system', 'tool', 'ipython')."
)
},
)
start: str = Field(
json_schema_extra={
"description": (
"Literal string that marks the start of this role's span in "
"the rendered chat template. Tokenized via "
"``tokenizer.encode(..., add_special_tokens=False)`` at "
"strategy init."
)
},
)
end: str | None = Field(
default=None,
json_schema_extra={
"description": (
"Literal string that marks the end of this role's span. "
"Set to ``eos_token`` to terminate at the tokenizer's EOS. "
"Leave unset / null to terminate at end-of-sequence."
)
},
)
include_start: bool = Field(
default=False,
json_schema_extra={
"description": (
"Whether the start marker tokens contribute to loss on "
"trainable turns. Default False."
)
},
)
include_end: bool = Field(
default=True,
json_schema_extra={
"description": (
"Whether the end marker tokens contribute to loss on "
"trainable turns (honoring cfg.train_on_eos). Default True."
)
},
)
class MultiModalConfig(BaseModel):
"""Multi-modal configuration subset"""
@@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel):
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
},
)
role_boundaries: list[RoleBoundarySpec] | None = Field(
default=None,
json_schema_extra={
"description": (
"Opt-in override for the MM mask scanner's per-role boundary "
"markers. Non-empty list replaces built-ins wholesale; unset "
"or empty falls back to built-ins. See "
"docs/multimodal_assistant_mask.md."
)
},
)
@field_validator("image_resize_algorithm", mode="before")
@classmethod

File diff suppressed because it is too large Load Diff