From 5352d41d32f4245bbb8b0140990ac0f6b77f1ab1 Mon Sep 17 00:00:00 2001
From: thad0ctor
Date: Tue, 5 May 2026 08:25:39 -0700
Subject: [PATCH] feat: systemic multimodal assistant-only loss masking +
 `cfg.role_boundaries` (#3625)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch, only
Gemma 3n honored these knobs; every other VLM trained on the full sequence
regardless of config. Also adds `cfg.role_boundaries` YAML override so users
can declare per-role markers without subclassing.

What changed
------------

- `ProcessingStrategy` gains a declarative boundary scanner. Each strategy
  declares per-role start/end markers via `_build_role_boundaries`; the
  shared scanner honors `train_on_inputs` / `roles_to_train` /
  `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral,
  Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously
  ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline (Voxtral,
  SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy
  behavior and emit a one-shot warning. Users can enable masking on them via
  `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]` token
  between user-end and assistant-start via `include_end=False` + scanner
  rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table, root-cause
analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context)

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse the
  first dataset's masking knobs (roles_to_train / train_on_eos) for all
  datasets — heterogeneous per-dataset overrides are not supported in the MM
  path today.
- processing_strategies.py: annotate inner scanner helpers _match_prefix and _find_end with explicit types (Tensor, int, list[int] → bool / tuple[int, bool]) for readability. - docs/multimodal_assistant_mask.md: renumber the "Commits on this branch" list to 1-7 consecutive (previously skipped 3). Co-Authored-By: Claude Opus 4.7 (1M context) * fix(mm-mask): address two CodeRabbit findings on PR #7 1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it. `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML users hit a pydantic ValidationError at config load. Added "none" to the Literal and updated the description. 2. `cfg.role_boundaries: []` had split-personality semantics: the strategy ctor treated it as "replace built-ins with empty" while the collator plumbing treated it as "unset", and both the design doc and the MultiModalConfig schema help text promised wholesale replacement for any set value. Aligned on opt-in semantics across all four surfaces — a non-empty list replaces built-ins wholesale; unset or `[]` falls back to built-ins. Rationale: honoring `[]` literally yields all-masked labels and zero gradient, which is almost always a typo or leftover rather than a deliberate user action. Users who want to disable role masking should unset the field or use `train_on_inputs: true`. Also sharpened the fallback one-shot warning for strategies without built-in boundaries: names the consequence ("only pad and media tokens are masked, every other token contributes to loss") and points users at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead of "see axolotl/processing_strategies.py for how to declare boundaries." Files: - src/axolotl/utils/schemas/datasets.py: Literal adds "none" - src/axolotl/processing_strategies.py: ctor truthiness check on role_boundaries_override; sharpened fallback warning - src/axolotl/utils/schemas/multimodal.py: role_boundaries description now calls out opt-in + empty-list fallback semantics - docs/multimodal_assistant_mask.md: same clarification in the Semantics block; updated the fallback-path detection paragraph to quote the new warning text - tests/test_processing_strategies.py: +2 regressions (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values, test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass Co-Authored-By: Claude Opus 4.7 (1M context) * doc cleanup * fix(mm-mask): CodeRabbit findings + lint fix on PR #3625 Pre-commit failure: trailing newline missing on docs/multimodal_assistant_mask.md (end-of-file-fixer hook). Six CodeRabbit findings addressed: 1. Scanner: non-trainable role's end marker ignored ``include_end``. Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end with ``include_end=False``, intentionally re-matched as assistant-start) leaked into loss via the user branch on Pixtral / Mistral V7 Tekken. Fix: gate the non-trainable branch on ``best_match.include_end`` to mirror the trainable branch. 2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``, which never fires on real checkpoints (``special_tokens_map`` only holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct attribute read ``getattr(tokenizer, "boi_token", None)``, matching what ``transformers.models.gemma3.processing_gemma3`` itself does. Updated the ``_gemma_tokenizer`` test fixture to mirror real-model shape so the test exercises the production code path. 3. 
GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V / GLM-4.7V).
Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell through to the base
fallback. Both processors ship identical media-token markers, so register
both under the shared ``Glm4vProcessingStrategy`` with independent
try/except import blocks. Updated class docstring. +2 dispatcher
regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
Resolve dynamically via
``tokenizer.convert_tokens_to_ids("<image_soft_token>")`` with unk-id guard;
fall back to 262144 only if the string isn't in vocab. Mirrors
``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
passes), producing two identical ``MM collator: ...`` INFO banners on
startup. Gate the log on ``is_eval=False`` so only the training pass emits
it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
always returned ``None``; the dispatcher already handles missing
``mistral_common`` via lazy import + ``try/except``).

Added ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
— a focused scanner-level lock-in for finding #1, independent of any
specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context)

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2 under
  the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
---
 docs/multimodal_assistant_mask.md       |   84 ++
 src/axolotl/core/builders/causal.py     |   41 +
 src/axolotl/processing_strategies.py    | 1031 +++++++++++++++-----
 src/axolotl/utils/schemas/datasets.py   |    4 +-
 src/axolotl/utils/schemas/multimodal.py |   62 ++
 tests/test_processing_strategies.py     | 1164 +++++++++++++++++++++++
 6 files changed, 2156 insertions(+), 230 deletions(-)
 create mode 100644 docs/multimodal_assistant_mask.md
 create mode 100644 tests/test_processing_strategies.py

diff --git a/docs/multimodal_assistant_mask.md b/docs/multimodal_assistant_mask.md
new file mode 100644
index 000000000..339ab420f
--- /dev/null
+++ b/docs/multimodal_assistant_mask.md
@@ -0,0 +1,84 @@
+# Multimodal assistant-only loss masking
+
+## Correct placement
+
+```yaml
+# Top-level: only train_on_inputs lives here.
+train_on_inputs: false
+
+datasets:
+  - path: data/train.jsonl
+    type: chat_template
+    roles_to_train:       # per-dataset — this is what the MM scanner reads
+      - assistant
+    train_on_eos: turn    # per-dataset — same
+
+test_datasets:
+  - path: data/val.jsonl
+    type: chat_template
+    split: train
+    roles_to_train:
+      - assistant
+    train_on_eos: turn
+```
+
+## How to verify at runtime
+
+`build_collator` logs the resolved knobs at INFO:
+
+```text
+MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
+```
+
+If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
+scanner — check that they are under `datasets[0]`, not at the root.
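+
+Beyond the log line, you can spot-check one sample end-to-end. A minimal
+sketch (the `strategy` and `processor` names here are assumptions — grab the
+resolved objects from a debugger or a small script; this is not a stable
+API):
+
+```python
+import torch
+
+# `messages` is any OpenAI-format conversation; `strategy` is the resolved
+# ProcessingStrategy. process_labels() applies role + pad/media masking.
+text = processor.tokenizer.apply_chat_template(messages, tokenize=False)
+ids = torch.tensor([processor.tokenizer.encode(text, add_special_tokens=False)])
+labels = strategy.process_labels(ids)
+
+# Decode only the tokens that still contribute to loss.
+kept = [t for t, l in zip(ids[0].tolist(), labels[0].tolist()) if l != -100]
+print(processor.tokenizer.decode(kept))  # expect assistant content (+ turn ends)
+```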
+Each verified strategy additionally logs its resolved boundary token ids at
+strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<end_of_turn>` → `[106]`
+for Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
+pad and media tokens are masked" one-shot warning instead, it is on the
+fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
+(below) to activate masking. The strategies currently on this path are
+listed in the audit table above under `fallback + warn`.
+
+## Config-based override: `cfg.role_boundaries`
+
+For the "unverified" strategies above, or for custom chat templates that
+don't match a built-in strategy's markers, users can declare role boundaries
+directly in YAML without subclassing:
+
+```yaml
+role_boundaries:
+  - role: assistant
+    start: "<|turn>model"
+    end: "<end_of_turn>"
+  - role: user
+    start: "<|turn>user"
+    end: "<end_of_turn>"
+  # Optional keys:
+  #   include_start: false  # default False
+  #   include_end: true     # default True, respects cfg.train_on_eos
+  #   end: eos_token        # sentinel: resolves to tokenizer.eos_token_id
+  #   end: null             # span runs to end of sequence
+```
+
+Semantics:
+
+- `start` and `end` are literal strings; axolotl encodes them at strategy
+  init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
+  resolved token-id sequences at INFO level.
+- The special value `end: eos_token` is the portable way to express
+  "Pixtral-style assistant turns end at EOS" without hard-coding an id.
+- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
+  the strategy's built-in declarations wholesale (partial overlays are
+  intentionally unsupported — they're hard to reason about at review time).
+  Leaving the field unset *or* setting it to an empty list (`[]`) both mean
+  "use the strategy's built-ins." Writing `role_boundaries: []` is almost
+  always a typo or leftover — honoring it literally would produce all-masked
+  labels and zero gradient, so it is treated the same as unset.
+- `cfg.roles_to_train` still governs which declared roles contribute to
+  loss. You can declare `user` and `assistant` boundaries and set
+  `roles_to_train: ["assistant"]` to have the scanner correctly identify
+  user spans as masking boundaries without training on their content.
+- Invalid specs fail loudly at strategy init (missing `role`/`start`,
+  unencodable markers), not silently at loss-compute time.

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 15624173d..aa1678523 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -515,12 +515,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         else:
             if self.cfg.processor_type and self.processor:
                 collator = MultiModalChatDataCollator
+                # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
+                # NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets;
+                # heterogeneous per-dataset overrides are not supported in the MM path today.
+                ds_entries = self.cfg.datasets or []
+                ds_cfg = ds_entries[0] if ds_entries else None
+
+                def _ds_get(cfg_obj, key):
+                    # Handle DictDefault / dict / pydantic uniformly:
+                    # dict-style .get first, then attribute access.
+ if cfg_obj is None: + return None + if hasattr(cfg_obj, "get"): + try: + return cfg_obj.get(key) + except (AttributeError, KeyError, TypeError): + pass + return getattr(cfg_obj, key, None) + + roles_to_train = _ds_get(ds_cfg, "roles_to_train") + train_on_eos = _ds_get(ds_cfg, "train_on_eos") + + # cfg.role_boundaries replaces the strategy's built-in markers. + role_boundaries_override = None + if self.cfg.role_boundaries: + role_boundaries_override = list(self.cfg.role_boundaries) + + # build() calls build_collator twice (eval + train); log once. + if not is_eval: + LOG.info( + "MM collator: train_on_inputs=%s roles_to_train=%s " + "train_on_eos=%s role_boundaries_override=%s", + bool(self.cfg.train_on_inputs), + roles_to_train, + train_on_eos, + "set" if role_boundaries_override else "none", + ) + kwargs["processing_strategy"] = get_processing_strategy( self.processor, training_args.chat_template, self.cfg.chat_template, image_size=training_args.image_size, image_resize_algorithm=training_args.image_resize_algorithm, + train_on_inputs=bool(self.cfg.train_on_inputs), + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, ) elif self.cfg.batch_flattening: collator = DataCollatorWithFlattening diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index cb1f9d984..217bc765b 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -1,6 +1,7 @@ """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" from copy import deepcopy +from dataclasses import dataclass, field from typing import Optional from PIL import Image, ImageOps @@ -17,9 +18,35 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +# One-shot warning dedupe so opt-out subclasses don't spam per-batch. +_ROLE_MASK_WARNED: set[str] = set() + +# Supported values for ``train_on_eos`` — mirrors the text-only +# ChatTemplateStrategy (``turn`` = trainable turn ends only, ``all`` = every +# turn end, ``none`` = never, ``last`` = only the final trainable turn end). +_VALID_TRAIN_ON_EOS = ("turn", "all", "none", "last") + + +@dataclass(frozen=True) +class RoleBoundary: + """One role's token-level span markers for the masking scanner. + + Empty ``end_tokens`` means end-of-sequence terminates the span. + """ + + role: str + start_tokens: list[int] + end_tokens: list[int] = field(default_factory=list) + include_start: bool = False + include_end: bool = True + class ProcessingStrategy: - """Base Processing Strategy class""" + """Base Processing Strategy class. + + Subclasses opt in to role masking by overriding ``_build_role_boundaries``; + otherwise only pad + media tokens are masked (legacy behavior, one-shot warned). + """ def __init__( self, @@ -27,6 +54,10 @@ class ProcessingStrategy: chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): self.processor = processor self.chat_template = chat_template @@ -38,54 +69,97 @@ class ProcessingStrategy: image_resize_algorithm or Image.Resampling.BILINEAR ) + # Defaults mirror the text-only ChatTemplateStrategy. 
An explicit + # empty list is honored as "no trainable roles" (masks everything); + # only ``None`` falls back to the default of assistant-only. + self.train_on_inputs = bool(train_on_inputs) + self.roles_to_train = ( + list(roles_to_train) if roles_to_train is not None else ["assistant"] + ) + self.train_on_eos = train_on_eos if train_on_eos is not None else "turn" + if self.train_on_eos not in _VALID_TRAIN_ON_EOS: + raise ValueError( + f"train_on_eos={self.train_on_eos!r} is not one of " + f"{_VALID_TRAIN_ON_EOS}." + ) + if hasattr(processor, "image_token"): self.image_token = processor.image_token self.image_token_id = processor.tokenizer.convert_tokens_to_ids( self.image_token ) + built_in = self._build_role_boundaries() + + # Truthiness check: empty list == unset (opt-in escape hatch), so + # `role_boundaries: []` in YAML falls through to built-ins instead of + # producing all-masked labels. + if role_boundaries_override: + overridden = _resolve_role_boundary_override( + role_boundaries_override, self.processor.tokenizer + ) + LOG.info( + "%s: overriding built-in role boundaries (%d decls) " + "with cfg.role_boundaries (%d decls).", + type(self).__name__, + len(built_in), + len(overridden), + ) + self.role_boundaries: list[RoleBoundary] = overridden + source = "override" + else: + self.role_boundaries = built_in + source = "built-in" + + # Single-line, grep-friendly summary of the resolved masking config so + # "why isn't masking firing?" is visible in training logs. For + # overrides we include the fully resolved (role, start_ids, end_ids) + # tuples; for built-ins we log a count (subclasses vary and logging + # every id sequence would be noisy on, e.g., Llama3 with five roles). + boundaries_repr: str | list[tuple[str, list[int], list[int]]] + if source == "override": + boundaries_repr = [ + (b.role, b.start_tokens, b.end_tokens) for b in self.role_boundaries + ] + else: + boundaries_repr = f"{len(self.role_boundaries)} built-in" + LOG.info( + "ProcessingStrategy init: class=%s train_on_inputs=%s " + "roles_to_train=%s train_on_eos=%s boundaries_source=%s " + "boundaries=%s", + type(self).__name__, + self.train_on_inputs, + self.roles_to_train, + self.train_on_eos, + source, + boundaries_repr, + ) + + def _build_role_boundaries(self) -> list[RoleBoundary]: + """Subclasses declare role boundaries here; [] opts out of role masking.""" + return [] + def __call__(self, examples: list[dict]) -> list[dict]: - """ - Preprocess conversation examples to ensure consistent format. - Converts different conversation formats to OpenAI format with 'messages'. - Supports two formats: - 1. OpenAI format with 'messages' - 2. Legacy format with 'conversations' - - Args: - examples: list of conversation dictionaries - - Returns: - list of dicts in OpenAI format with 'messages' key - - Raises: - ValueError: If the conversation format is not supported - """ + """Normalize examples to OpenAI ``messages`` format (accepts legacy ``conversations``).""" role_mapping = { "human": "user", "gpt": "assistant", } def normalize_role(role: str) -> str: - """Normalize role names to OpenAI format. 
Default to original role if not found.""" return role_mapping.get(role, role) def convert_legacy_format(example: dict) -> dict: - """Convert legacy 'conversations' format to OpenAI 'messages' format.""" messages = [ {"role": normalize_role(convo["from"]), "content": convo["value"]} for convo in example["conversations"] ] - - # Create new dict without 'conversations' key result = deepcopy(example) result.pop("conversations") result["messages"] = messages return result def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: - """Convert regular messages format to Messages format with content type""" - new_messages = [] for message in messages: if isinstance(message["content"], str): @@ -119,21 +193,27 @@ class ProcessingStrategy: "Only `messages` and `conversations` message keys are currently supported." ) - processed_example = None - if ( - "messages" in example and example["messages"] is not None - ): # OpenAI format - processed_example = example - else: # Legacy format + if "messages" in example and example["messages"] is not None: + # Deepcopy for symmetry with convert_legacy_format (which + # deepcopies internally) so downstream mutations of + # processed_example don't leak back to the caller's input. + processed_example = deepcopy(example) + elif "conversations" in example: processed_example = convert_legacy_format(example) + else: + # `messages` is present but None, and no `conversations` + # fallback exists — convert_legacy_format would KeyError on + # ["conversations"]. Surface a clear validation error instead. + raise ValueError( + "`messages` is present but None; provide non-null " + "`messages` or a `conversations` field." + ) - # convert regular messages format to Messages format with content type - # for compatibility with apply_chat_template + # Required for apply_chat_template compatibility. processed_example["messages"] = convert_messages_to_multimedia_messages( processed_example["messages"] ) - # find the image key if it exists possible_image_keys = ["images", "image"] image_key = None for key in possible_image_keys: @@ -141,11 +221,8 @@ class ProcessingStrategy: image_key = key break - # if the image key exists, add the image to the first user message if image_key is not None and processed_example[image_key] is not None: - # TODO: check if it's normal to be single image only for common datasets - # From observation, it's usually a list of single image but some datasets may have several columns for images - # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages + # TODO: support multi-image samples; for now we take the first. if len(processed_example[image_key]) > 1: LOG.warning( f"Found {len(processed_example[image_key])} images in a sample. Using the first one." @@ -155,7 +232,6 @@ class ProcessingStrategy: image_value = processed_example[image_key][0] - # Handle image loading (Image, url, path, base64) image_value = load_image(image_value) if self.image_size is not None: @@ -168,11 +244,8 @@ class ProcessingStrategy: self.image_size, self.image_resize_algorithm ) else: - # Set the padding value; here we use black (0, 0, 0) for RGB images + # Int image_size: preserve aspect ratio then pad to square (black) to avoid distortion. 
                padding_color = (0, 0, 0)
-
-                # When image_size is an int (square target), preserve aspect ratio then pad
-                # This is to prevent aspect ratio distortion when resizing to square
                 image_value = ImageOps.pad(
                     image_value,
                     (self.image_size, self.image_size),
@@ -180,8 +253,6 @@ class ProcessingStrategy:
                     color=padding_color,
                 )

-        # Look for any image type in the first message
-        # some dataset have an {type: "image"} in the first message
         msg_ind_to_add = None
         ind_to_add = None
         first_user_idx = None
@@ -192,7 +263,7 @@ class ProcessingStrategy:
                 for i, content in enumerate(
                     processed_example["messages"][msg_idx]["content"]
                 ):
-                    # Usually datasets created with image columns, don't have it in the messages itself
+                    # Column-image datasets often leave a bare {type: "image"} placeholder.
                     if content["type"] == "image" and all(
                         k not in content for k in ["image", "url", "path", "base64"]
                     ):
@@ -200,13 +271,11 @@ class ProcessingStrategy:
                         ind_to_add = i
                         break

-            # If an image type is found, add the image to that index
             if ind_to_add is not None and msg_ind_to_add is not None:
                 processed_example["messages"][msg_ind_to_add]["content"][
                     ind_to_add
                 ]["image"] = image_value
             else:
-                # if no image type is found, add it to end of the first user message
                 if first_user_idx is None:
                     first_user_idx = 0
                 processed_example["messages"][first_user_idx]["content"].append(
@@ -221,28 +290,224 @@ class ProcessingStrategy:

         return processed_examples

     def _mask_non_assistant(self, labels: Tensor) -> Tensor:
-        """
-        Mask non assistant regions to -100.
-        To be implemented per subclass.
-        """
-        return labels
+        """Mask non-trainable role regions to -100 using ``self.role_boundaries``."""
+        if self.train_on_inputs:
+            return labels
+
+        # Legacy no-op for boundary-less strategies; warn once so the miss shows up in logs.
+        if not self.role_boundaries:
+            key = type(self).__name__
+            if key not in _ROLE_MASK_WARNED:
+                _ROLE_MASK_WARNED.add(key)
+                LOG.warning(
+                    "%s has no built-in role boundaries; "
+                    "cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos "
+                    "will NOT restrict loss to assistant tokens for this "
+                    "multimodal model — only pad and media tokens are masked, "
+                    "every other token (system, user, assistant) contributes "
+                    "to loss. To enable assistant-only masking, declare "
+                    "per-role markers in YAML via cfg.role_boundaries — see "
+                    "docs/multimodal_assistant_mask.md for the format and the "
+                    "list of strategies on this fallback path.",
+                    key,
+                )
+            return labels
+
+        return _apply_role_boundaries(
+            labels,
+            self.role_boundaries,
+            roles_to_train=set(self.roles_to_train),
+            train_on_eos=self.train_on_eos,
+        )

     def process_labels(self, input_ids: Tensor) -> Tensor:
         labels = input_ids.clone()
         labels = self._mask_non_assistant(labels)
-
-        # The labels are the input_ids, and we mask the padding tokens in the loss computation
-        labels[labels == self.processor.tokenizer.pad_token_id] = -100
-
-        # Ignore the image token index in the loss computation (model specific)
-        labels[labels == self.image_token_id] = -100
-
+        pad_id = getattr(self.processor.tokenizer, "pad_token_id", None)
+        if pad_id is not None:
+            labels[labels == pad_id] = -100
+        if self.image_token_id is not None:
+            labels[labels == self.image_token_id] = -100
         return labels


+def _apply_role_boundaries(
+    labels: Tensor,
+    role_boundaries: list[RoleBoundary],
+    roles_to_train: set[str],
+    train_on_eos: str,
+) -> Tensor:
+    """Mask tokens outside trainable role spans to -100.
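+
+    Worked micro-example (made-up ids): with user = start ``[1]`` / end ``[2]``
+    / ``include_end=False`` and assistant = start ``[2]`` / end ``[3]``, the
+    row ``[1, 7, 2, 8, 3]`` keeps only ``8`` in loss (plus the end marker
+    ``3`` when ``train_on_eos`` allows) — the scanner rewinds to re-match
+    ``2`` as the assistant start.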
+ + Scan is greedy-left with longest-prefix-wins on start_tokens to disambiguate + nested markers (e.g. ``<|im_start|>assistant`` vs ``<|im_start|>``). + ``train_on_eos`` accepts ``"turn"`` (end marker in loss on trainable turns + only), ``"all"`` (always), ``"none"`` (never — overrides ``include_end``), + ``"last"`` (only on the last trainable turn in the sequence). + """ + mask = zeros_like(labels) + # For "last": remember each trainable turn's end-marker span so we can + # unmask only the final one after the scan finishes. + last_trainable_end_span: list[Optional[tuple[int, int]]] = [None] * labels.shape[0] + + # Work on a Python list per row — avoids O(n*boundaries) Tensor→list + # conversions in the hot prefix-match loop. + def _match_prefix(label: list[int], start_pos: int, tok_seq: list[int]) -> bool: + if not tok_seq or start_pos + len(tok_seq) > len(label): + return False + return label[start_pos : start_pos + len(tok_seq)] == tok_seq + + def _find_end( + label: list[int], start_pos: int, end_tok: list[int] + ) -> tuple[int, bool]: + # Empty end_tok means run to end-of-sequence. + if not end_tok: + return len(label), False + k = start_pos + while k < len(label): + if _match_prefix(label, k, end_tok): + return k + len(end_tok), True + k += 1 + return k, False + + for i in range(labels.shape[0]): + label = labels[i].tolist() + j = 0 + n = len(label) + while j < n: + best_match: Optional[RoleBoundary] = None + for b in role_boundaries: + if _match_prefix(label, j, b.start_tokens): + if best_match is None or len(b.start_tokens) > len( + best_match.start_tokens + ): + best_match = b + if best_match is None: + j += 1 + continue + + start_of_content = j + len(best_match.start_tokens) + end_after, found_end = _find_end( + label, start_of_content, best_match.end_tokens + ) + + role_in_loss = best_match.role in roles_to_train + + if role_in_loss: + if best_match.include_start: + mask[i][j:start_of_content] = 1 + content_end = ( + end_after - len(best_match.end_tokens) if found_end else end_after + ) + mask[i][start_of_content:content_end] = 1 + # train_on_eos="none"/"last" override include_end during main + # loop; "last" is applied after the scan finishes. + if ( + found_end + and best_match.include_end + and train_on_eos not in ("none", "last") + ): + mask[i][content_end:end_after] = 1 + if found_end and best_match.include_end and train_on_eos == "last": + last_trainable_end_span[i] = (content_end, end_after) + else: + # Non-trainable role on train_on_eos="all": gate on include_end + # so Pixtral / Mistral V7 Tekken shared [/INST] doesn't leak. + if found_end and best_match.include_end and train_on_eos == "all": + content_end = end_after - len(best_match.end_tokens) + mask[i][content_end:end_after] = 1 + + # When include_end=False, do not consume the end marker: back up so + # the next iteration can re-match it as the next boundary's start + # marker (Pixtral / Mistral V7 Tekken share [/INST] between + # user-end and assistant-start). Requires end_tokens non-empty and + # actually found. 
+ if found_end and not best_match.include_end and best_match.end_tokens: + j = end_after - len(best_match.end_tokens) + else: + j = end_after + + if train_on_eos == "last" and (span := last_trainable_end_span[i]) is not None: + s, e = span + mask[i][s:e] = 1 + + labels[i][mask[i] == 0] = -100 + + return labels + + +def _encode_markers(tokenizer, marker_strs: list[str]) -> list[list[int]]: + """Encode markers via ``encode(..., add_special_tokens=False)``; drops empty results.""" + result = [] + for s in marker_strs: + toks = tokenizer.encode(s, add_special_tokens=False) + if toks: + result.append(toks) + return result + + +def _resolve_role_boundary_override(specs: list[dict], tokenizer) -> list[RoleBoundary]: + """Resolve user ``cfg.role_boundaries`` specs into RoleBoundary objects. + + The sentinel ``end == "eos_token"`` resolves to ``eos_token_id`` (used by + Pixtral/Mistral v7 templates). ``end`` null/omitted runs to end-of-sequence. + """ + out: list[RoleBoundary] = [] + for i, spec in enumerate(specs): + if hasattr(spec, "model_dump"): + d = spec.model_dump() + else: + d = dict(spec) + + role = d.get("role") + start_str = d.get("start") + if not role or start_str is None: + raise ValueError( + f"cfg.role_boundaries[{i}] must have both 'role' and 'start' " + f"(got {d!r})." + ) + start_ids = tokenizer.encode(start_str, add_special_tokens=False) + if not start_ids: + raise ValueError( + f"cfg.role_boundaries[{i}]: start marker {start_str!r} " + f"tokenizes to an empty sequence; cannot match." + ) + + end_spec = d.get("end") + if end_spec is None: + end_ids: list[int] = [] + elif end_spec == "eos_token": + eos = getattr(tokenizer, "eos_token_id", None) + if eos is None: + raise ValueError( + f"cfg.role_boundaries[{i}] requested end='eos_token' but " + "the tokenizer has no eos_token_id." + ) + end_ids = [eos] + else: + end_ids = tokenizer.encode(end_spec, add_special_tokens=False) + if not end_ids: + raise ValueError( + f"cfg.role_boundaries[{i}]: end marker {end_spec!r} " + f"tokenizes to an empty sequence; cannot match. Use " + f"end=null to run to end-of-sequence or end='eos_token' " + f"to terminate at the tokenizer's EOS." + ) + + out.append( + RoleBoundary( + role=role, + start_tokens=start_ids, + end_tokens=end_ids, + include_start=bool(d.get("include_start", False)), + include_end=bool(d.get("include_end", True)), + ) + ) + return out + + class Qwen2VLProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Qwen2-VL""" + """Processing Strategy class for Qwen2-VL (ChatML ``<|im_start|>{role}\\n ... 
<|im_end|>``)."""

     def __init__(
         self,
@@ -250,16 +515,44 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy):
         chat_template: Optional[str] = None,
         image_size: int | tuple[int, int] | None = None,
         image_resize_algorithm: Resampling | None = None,
+        train_on_inputs: bool = False,
+        roles_to_train: Optional[list[str]] = None,
+        train_on_eos: Optional[str] = None,
+        role_boundaries_override: Optional[list[dict]] = None,
     ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        super().__init__(
+            processor,
+            chat_template,
+            image_size,
+            image_resize_algorithm,
+            train_on_inputs=train_on_inputs,
+            roles_to_train=roles_to_train,
+            train_on_eos=train_on_eos,
+            role_boundaries_override=role_boundaries_override,
+        )
         self.image_token = "<|image_pad|>"  # nosec
         self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
             self.image_token
         )

+    def _build_role_boundaries(self) -> list[RoleBoundary]:
+        tok = self.processor.tokenizer
+        end = _encode_markers(tok, ["<|im_end|>"])
+        if not end:
+            return []
+        end_ids = end[0]
+        boundaries = []
+        for role in ("system", "user", "assistant"):
+            start = _encode_markers(tok, [f"<|im_start|>{role}\n"])
+            if start:
+                boundaries.append(
+                    RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids)
+                )
+        return boundaries

-class Qwen3_5ProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for Qwen3.5 (early-fusion VLM)"""
+
+class Qwen3_5ProcessingStrategy(Qwen2VLProcessingStrategy):
+    """Processing Strategy class for Qwen3.5 (Qwen2-VL boundaries + ``<|video_pad|>`` mask)."""

     def __init__(
         self,
@@ -267,11 +560,20 @@ class Qwen3_5ProcessingStrategy(ProcessingStrategy):
         chat_template: Optional[str] = None,
         image_size: int | tuple[int, int] | None = None,
         image_resize_algorithm: Resampling | None = None,
+        train_on_inputs: bool = False,
+        roles_to_train: Optional[list[str]] = None,
+        train_on_eos: Optional[str] = None,
+        role_boundaries_override: Optional[list[dict]] = None,
     ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
-        self.image_token = "<|image_pad|>"  # nosec
-        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
-            self.image_token
+        super().__init__(
+            processor,
+            chat_template,
+            image_size,
+            image_resize_algorithm,
+            train_on_inputs=train_on_inputs,
+            roles_to_train=roles_to_train,
+            train_on_eos=train_on_eos,
+            role_boundaries_override=role_boundaries_override,
         )
         self.video_token = "<|video_pad|>"  # nosec
         self.video_token_id = processor.tokenizer.convert_tokens_to_ids(
@@ -280,12 +582,44 @@ class Qwen3_5ProcessingStrategy(ProcessingStrategy):

     def process_labels(self, input_ids):
         labels = super().process_labels(input_ids)
-        labels[labels == self.video_token_id] = -100
+        if self.video_token_id is not None:
+            labels[labels == self.video_token_id] = -100
         return labels


-class Gemma3ProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for Gemma3"""
+class _GemmaTurnStrategy(ProcessingStrategy):
+    """Gemma3/3n ``<start_of_turn>{role} ... <end_of_turn>`` (Gemma 4 uses different markers)."""
+
+    def _build_role_boundaries(self) -> list[RoleBoundary]:
+        tok = self.processor.tokenizer
+        end = _encode_markers(tok, ["<end_of_turn>"])
+        if not end:
+            return []
+        end_ids = end[0]
+        boundaries = []
+        # Template uses 'model'; external role knob stays 'assistant'. Gemma 3
+        # and Gemma 3n jinja templates fold the system message into the first
+        # user's content prefix and never emit '<start_of_turn>system', so we
+        # don't declare a system boundary here.
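+        # Illustrative result (id values vary by checkpoint): for Gemma3 this
+        # yields e.g. RoleBoundary(role="assistant",
+        #     start_tokens=encode("<start_of_turn>model\n"),
+        #     end_tokens=encode("<end_of_turn>")).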
+        role_marker_pairs = [
+            ("assistant", "model"),
+            ("user", "user"),
+        ]
+        for external_role, template_role in role_marker_pairs:
+            start = _encode_markers(tok, [f"<start_of_turn>{template_role}\n"])
+            if start:
+                boundaries.append(
+                    RoleBoundary(
+                        role=external_role,
+                        start_tokens=start[0],
+                        end_tokens=end_ids,
+                    )
+                )
+        return boundaries
+
+
+class Gemma3ProcessingStrategy(_GemmaTurnStrategy):
+    """Processing Strategy class for Gemma3."""

     def __init__(
         self,
@@ -293,119 +627,247 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
         chat_template: Optional[str] = None,
         image_size: int | tuple[int, int] | None = None,
         image_resize_algorithm: Resampling | None = None,
+        train_on_inputs: bool = False,
+        roles_to_train: Optional[list[str]] = None,
+        train_on_eos: Optional[str] = None,
+        role_boundaries_override: Optional[list[dict]] = None,
     ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
-        self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
-        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
-            self.image_token
+        super().__init__(
+            processor,
+            chat_template,
+            image_size,
+            image_resize_algorithm,
+            train_on_inputs=train_on_inputs,
+            roles_to_train=roles_to_train,
+            train_on_eos=train_on_eos,
+            role_boundaries_override=role_boundaries_override,
         )
+        # Real Gemma3 tokenizers expose boi_token as a direct attribute, not
+        # via special_tokens_map (which only holds HF's standard slots).
+        boi = getattr(processor.tokenizer, "boi_token", None)
+        if boi is not None:
+            self.image_token = boi
+            self.image_token_id = processor.tokenizer.convert_tokens_to_ids(boi)

     def process_labels(self, input_ids):
-        labels = input_ids.clone()
-
-        # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
-        labels[labels == self.processor.tokenizer.pad_token_id] = -100
-        labels[labels == self.image_token_id] = -100
-        labels[labels == 262144] = -100  # corresponds to <image_soft_token>
-
+        labels = super().process_labels(input_ids)
+        # Resolve <image_soft_token> via tokenizer; fall back to the default
+        # id 262144 if it's not in vocab. Matches Gemma4's pattern.
+        tok = self.processor.tokenizer
+        soft_id = tok.convert_tokens_to_ids("<image_soft_token>")
+        unk_id = getattr(tok, "unk_token_id", None)
+        if soft_id is not None and soft_id != unk_id:
+            labels[labels == soft_id] = -100
+        else:
+            labels[labels == 262144] = -100
         return labels


-class Gemma3nProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for Gemma3n"""
+class Gemma3nProcessingStrategy(_GemmaTurnStrategy):
+    """Gemma3n: same turn boundaries as Gemma3, additionally masks audio/delimiter tokens."""

-    def _mask_non_assistant(self, labels: Tensor) -> Tensor:
-        def _find_token_sequence(label, start_pos, token_sequence):
-            """Check if token_sequence appears at start_pos in label"""
-            if start_pos + len(token_sequence) > len(label):
-                return False
-            if label[start_pos] != token_sequence[0]:
-                return False
-            return (
-                label[start_pos : start_pos + len(token_sequence)].tolist()
-                == token_sequence
-            )
+    def process_labels(self, input_ids):
+        labels = super().process_labels(input_ids)
+        tok = self.processor.tokenizer
+        # Follows huggingface-gemma-recipes fine_tune_gemma3n_on_t4 notebook.
+        for attr in (
+            "image_token_id",
+            "audio_token_id",
+            "boi_token_id",
+            "eoi_token_id",
+        ):
+            tok_id = getattr(tok, attr, None)
+            if tok_id is not None:
+                labels[labels == tok_id] = -100
+        return labels

-        def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i):
-            """
-            Find the end of assistant response and update mask accordingly
-            Returns new position to continue from and whether the end seq is found
-            """
-            k = start_pos
-            while k < len(label):
-                if not _find_token_sequence(label, k, assistant_end_tok):
-                    mask[i][k] = 1
-                    k += 1
-                    continue

+class Gemma4ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Gemma 4.

-                return k + len(assistant_end_tok), True
+    Boundary markers ``<|turn>model ... <end_of_turn>`` verified against
+    google/gemma-4-E2B-it. boi/eoi/boa/eoa ids are resolved via
+    ``convert_tokens_to_ids`` since only their string forms are on the processor.
+    """

-            return k, False
-
-        mask = zeros_like(labels)
-
-        assistant_start_str = "<start_of_turn>model"
-        assistant_end_str = "<end_of_turn>"
-        include_assistant_start_tok = False
-        include_assistant_end_tok = True
-
-        # str to tokens
-        assistant_start_tok = self.processor.tokenizer.encode(
-            assistant_start_str, add_special_tokens=False
-        )
-        assistant_end_tok = self.processor.tokenizer.encode(
-            assistant_end_str, add_special_tokens=False
-        )
-
-        for i, label in enumerate(labels):
-            j = 0
-            # while loop through each tok index in labels[i]
-            while j < len(label):
-                # Check until match start seq
-                if not _find_token_sequence(label, j, assistant_start_tok):
-                    j += 1
-                    continue
-
-                if include_assistant_start_tok:
-                    mask[i][j : j + len(assistant_start_tok)] = 1
-
-                # Find where the assistant response ends
-                start_of_content = j + len(assistant_start_tok)
-                end_pos, found_end_seq = _find_assistant_end(
-                    label, start_of_content, assistant_end_tok, mask, i
+    def _build_role_boundaries(self) -> list[RoleBoundary]:
+        tok = self.processor.tokenizer
+        end = _encode_markers(tok, ["<end_of_turn>"])
+        if not end:
+            return []
+        end_ids = end[0]
+        boundaries = []
+        role_marker_pairs = [
+            ("assistant", "model"),
+            ("user", "user"),
+            ("system", "system"),
+        ]
+        for external_role, template_role in role_marker_pairs:
+            # Include trailing ``\n`` for consistency with Qwen/Gemma3/Llama
+            # markers; the newline is part of the marker in the real
+            # google/gemma-4 tokenizer's chat template.
+ start = _encode_markers(tok, [f"<|turn>{template_role}\n"]) + if start: + boundaries.append( + RoleBoundary( + role=external_role, + start_tokens=start[0], + end_tokens=end_ids, + ) ) - - # Include end token if requested - if include_assistant_end_tok and found_end_seq: - mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1 - - j = end_pos - - labels[i][mask[i] == 0] = -100 - - return labels + return boundaries def process_labels(self, input_ids): - labels = input_ids.clone() - labels = self._mask_non_assistant(labels) + labels = super().process_labels(input_ids) - # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - if hasattr(self.processor.tokenizer, "image_token_id"): - labels[labels == self.processor.tokenizer.image_token_id] = -100 - if hasattr(self.processor.tokenizer, "audio_token_id"): - labels[labels == self.processor.tokenizer.audio_token_id] = -100 - if hasattr(self.processor.tokenizer, "boi_token_id"): - labels[labels == self.processor.tokenizer.boi_token_id] = -100 - if hasattr(self.processor.tokenizer, "eoi_token_id"): - labels[labels == self.processor.tokenizer.eoi_token_id] = -100 + tokenizer = self.processor.tokenizer + unk_id = getattr(tokenizer, "unk_token_id", None) + + if getattr(tokenizer, "image_token_id", None) is not None: + labels[labels == tokenizer.image_token_id] = -100 + if getattr(tokenizer, "audio_token_id", None) is not None: + labels[labels == tokenizer.audio_token_id] = -100 + + # boi/eoi/boa/eoa are only string attrs on the processor; resolve ids here. + for attr in ("boi_token", "eoi_token", "boa_token", "eoa_token"): + token_str = getattr(self.processor, attr, None) + if token_str is None: + continue + token_id = tokenizer.convert_tokens_to_ids(token_str) + if token_id is None or token_id == unk_id: + continue + labels[labels == token_id] = -100 + + # Video id lives on the processor, not the tokenizer. + video_token_id = getattr(self.processor, "video_token_id", None) + if video_token_id is not None and video_token_id != unk_id: + labels[labels == video_token_id] = -100 return labels +class Llama3_2VisionProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Llama-3.2 Vision (``<|start_header_id|>{role}<|end_header_id|>\\n\\n ... <|eot_id|>``).""" + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, ["<|eot_id|>"]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + for role in ("system", "user", "assistant", "ipython", "tool"): + start = _encode_markers( + tok, [f"<|start_header_id|>{role}<|end_header_id|>\n\n"] + ) + if start: + boundaries.append( + RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids) + ) + return boundaries + + +class Llama4ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Llama 4 (``<|header_start|>{role}<|header_end|>\\n\\n ... 
<|eot|>``).""" + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, ["<|eot|>"]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + for role in ("system", "user", "assistant", "ipython", "tool"): + start = _encode_markers(tok, [f"<|header_start|>{role}<|header_end|>\n\n"]) + if start: + boundaries.append( + RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids) + ) + return boundaries + + +class PixtralProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Pixtral (``[INST] ... [/INST]`` user, assistant terminates at ``eos_token``). + + ``[/INST]`` is shared between user-end and assistant-start. We declare user + with ``include_end=False`` so the scanner hands the ``[/INST]`` back to + assistant's start match on the next iteration. + """ + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + eos = getattr(tok, "eos_token_id", None) + if eos is None: + return [] + boundaries = [] + inst_start = _encode_markers(tok, ["[INST]"]) + inst_end = _encode_markers(tok, ["[/INST]"]) + if inst_start and inst_end: + boundaries.append( + RoleBoundary( + role="user", + start_tokens=inst_start[0], + end_tokens=inst_end[0], + include_end=False, + ) + ) + boundaries.append( + RoleBoundary( + role="assistant", + start_tokens=inst_end[0], + end_tokens=[eos], + ) + ) + return boundaries + + +class MistralV7TekkenProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Mistral v7 Tekken (Pixtral-style plus ``[SYSTEM_PROMPT]...[/SYSTEM_PROMPT]``). + + Same ``[/INST]``-shared-marker treatment as :class:`PixtralProcessingStrategy`. + """ + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + eos = getattr(tok, "eos_token_id", None) + if eos is None: + return [] + boundaries = [] + sys_start = _encode_markers(tok, ["[SYSTEM_PROMPT]"]) + sys_end = _encode_markers(tok, ["[/SYSTEM_PROMPT]"]) + if sys_start and sys_end: + boundaries.append( + RoleBoundary( + role="system", start_tokens=sys_start[0], end_tokens=sys_end[0] + ) + ) + inst_start = _encode_markers(tok, ["[INST]"]) + inst_end = _encode_markers(tok, ["[/INST]"]) + if inst_start and inst_end: + boundaries.append( + RoleBoundary( + role="user", + start_tokens=inst_start[0], + end_tokens=inst_end[0], + include_end=False, + ) + ) + boundaries.append( + RoleBoundary( + role="assistant", + start_tokens=inst_end[0], + end_tokens=[eos], + ) + ) + return boundaries + + class VoxtralProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Voxtral""" + """Processing Strategy class for Voxtral. + + Role boundaries NOT declared — mistral-common instruct tokenizer markers + unverified. Falls back to pad+audio masking with a one-shot warning. 
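+
+    A plausible opt-in override for users who have checked their rendered
+    template (the marker strings below are unverified assumptions, not
+    shipped defaults — confirm them before trusting the resulting mask)::
+
+        role_boundaries:
+          - role: assistant
+            start: "[/INST]"
+            end: eos_token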
+ """ def __init__( self, @@ -413,8 +875,21 @@ class VoxtralProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) special_ids = ( processor.tokenizer.tokenizer.instruct_tokenizer.audio_encoder.special_ids ) @@ -424,16 +899,25 @@ class VoxtralProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - labels[labels == self.audio_token] = -100 - labels[labels == self.begin_audio_token] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + if self.audio_token is not None: + labels[labels == self.audio_token] = -100 + if self.begin_audio_token is not None: + labels[labels == self.begin_audio_token] = -100 return labels class SmolVLM2ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for SmolVLM2""" + """Processing Strategy class for SmolVLM2. + + Role boundaries NOT declared — SmolVLM2 chat_template varies per checkpoint + (HuggingFaceTB ships multiple variants), so we opt out rather than mis-mask. + """ def __init__( self, @@ -441,8 +925,21 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) self.image_token = "" # nosec self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ @@ -451,7 +948,11 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy): class Mistral3ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Mistral3""" + """Processing Strategy class for Mistral3. + + Role boundaries NOT declared (mistral-common instruct tokenizer unverified); + same fallback as VoxtralProcessingStrategy. 
+ """ def __init__( self, @@ -459,8 +960,21 @@ class Mistral3ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) special_ids = ( processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids ) @@ -471,17 +985,24 @@ class Mistral3ProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - labels[labels == self.image_token] = -100 - labels[labels == self.image_break_token] = -100 - labels[labels == self.image_end_token] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + for tok_id in (self.image_token, self.image_break_token, self.image_end_token): + if tok_id is not None: + labels[labels == tok_id] = -100 return labels class InternVLProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for InternVL""" + """Processing Strategy class for InternVL. + + Role boundaries NOT declared (InternLM-style template unverified); falls + back to pad + image-id masking with a one-shot warning. + """ def __init__( self, @@ -489,8 +1010,21 @@ class InternVLProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) if not hasattr(processor, "image_ids"): raise ValueError("'image_ids' missing from InternVL Processor.") @@ -499,20 +1033,26 @@ class InternVLProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 for ids in self.image_token_ids: - labels[labels == ids] = -100 - - # Note: Check if need to mask 'video_token' as it gets converted to - # image patches during media processing + if ids is not None: + labels[labels == ids] = -100 + # Video tokens get converted to image patches during media processing; masking may be redundant. 
        return labels


 class Glm4vProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
+    """Shared strategy for Glm4vProcessor (GLM-4V / GLM-4.1V) and
+    Glm46VProcessor (GLM-4.6V / GLM-4.7V) — identical media-token markers.
+
+    Role boundaries unverified; use cfg.role_boundaries to enable masking.
+    """

     def __init__(
         self,
@@ -520,8 +1060,21 @@ class Glm4vProcessingStrategy(ProcessingStrategy):
         chat_template: Optional[str] = None,
         image_size: int | tuple[int, int] | None = None,
         image_resize_algorithm: Resampling | None = None,
+        train_on_inputs: bool = False,
+        roles_to_train: Optional[list[str]] = None,
+        train_on_eos: Optional[str] = None,
+        role_boundaries_override: Optional[list[dict]] = None,
     ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        super().__init__(
+            processor,
+            chat_template,
+            image_size,
+            image_resize_algorithm,
+            train_on_inputs=train_on_inputs,
+            roles_to_train=roles_to_train,
+            train_on_eos=train_on_eos,
+            role_boundaries_override=role_boundaries_override,
+        )

         self.tokenizer = getattr(processor, "tokenizer", processor)

@@ -549,16 +1102,22 @@ class Glm4vProcessingStrategy(ProcessingStrategy):

     def process_labels(self, input_ids):
         labels = input_ids.clone()
+        labels = self._mask_non_assistant(labels)

-        labels[labels == self.tokenizer.pad_token_id] = -100
+        pad_id = getattr(self.tokenizer, "pad_token_id", None)
+        if pad_id is not None:
+            labels[labels == pad_id] = -100

-        labels[labels == self.image_token_id] = -100
-        labels[labels == self.begin_image_token_id] = -100
-        labels[labels == self.end_image_token_id] = -100
-
-        labels[labels == self.video_token_id] = -100
-        labels[labels == self.begin_video_token_id] = -100
-        labels[labels == self.end_video_token_id] = -100
+        for tok_id in (
+            self.image_token_id,
+            self.begin_image_token_id,
+            self.end_image_token_id,
+            self.video_token_id,
+            self.begin_video_token_id,
+            self.end_video_token_id,
+        ):
+            if tok_id is not None:
+                labels[labels == tok_id] = -100

         return labels


@@ -569,14 +1128,20 @@ def get_processing_strategy(
     chat_template_type,
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
+    train_on_inputs: bool = False,
+    roles_to_train: Optional[list[str]] = None,
+    train_on_eos: Optional[str] = None,
+    role_boundaries_override: Optional[list[dict]] = None,
 ):
-    from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
-
     processing_kwargs = {
         "processor": processor,
         "chat_template": chat_template,
         "image_size": image_size,
        "image_resize_algorithm": image_resize_algorithm,
+        "train_on_inputs": train_on_inputs,
+        "roles_to_train": roles_to_train,
+        "train_on_eos": train_on_eos,
+        "role_boundaries_override": role_boundaries_override,
     }

     if chat_template_type in [None, "tokenizer_default"]:
@@ -585,53 +1150,63 @@ def get_processing_strategy(
         processing_kwargs["chat_template"] = tokenizer.chat_template

     if chat_template_type == "qwen2_vl":
-        return Qwen2VLProcessingStrategy(
-            **processing_kwargs,
-        )
-    if chat_template_type in ["qwen3_5", "qwen3_5_moe"]:
-        return Qwen3_5ProcessingStrategy(
-            **processing_kwargs,
-        )
+        return Qwen2VLProcessingStrategy(**processing_kwargs)
+    if chat_template_type in ("qwen3_5", "qwen3_5_moe"):
+        return Qwen3_5ProcessingStrategy(**processing_kwargs)
     if chat_template_type == "gemma3":
-        return Gemma3ProcessingStrategy(
-            **processing_kwargs,
-        )
+        return Gemma3ProcessingStrategy(**processing_kwargs)
     if chat_template_type == "gemma3n":
-        return
Gemma3nProcessingStrategy( - **processing_kwargs, - ) + return Gemma3nProcessingStrategy(**processing_kwargs) + if chat_template_type == "gemma4": + return Gemma4ProcessingStrategy(**processing_kwargs) + if chat_template_type == "llama3_2_vision": + return Llama3_2VisionProcessingStrategy(**processing_kwargs) + if chat_template_type == "llama4": + return Llama4ProcessingStrategy(**processing_kwargs) + if chat_template_type == "pixtral": + return PixtralProcessingStrategy(**processing_kwargs) + if chat_template_type == "mistral_v7_tekken": + return MistralV7TekkenProcessingStrategy(**processing_kwargs) if isinstance(processor, VoxtralProcessor): - return VoxtralProcessingStrategy( - **processing_kwargs, - ) + return VoxtralProcessingStrategy(**processing_kwargs) if isinstance(processor, SmolVLMProcessor): - return SmolVLM2ProcessingStrategy( - **processing_kwargs, + return SmolVLM2ProcessingStrategy(**processing_kwargs) + + # Lazy import: mistral_common is optional. Mirrors the Glm46V pattern below. + try: + from axolotl.utils.mistral.mistral3_processor import Mistral3Processor + + if isinstance(processor, Mistral3Processor): + return Mistral3ProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug( + "Mistral3Processor import failed; Mistral3 strategy will be unavailable: %r", + exc, ) - if isinstance(processor, Mistral3Processor): - return Mistral3ProcessingStrategy( - **processing_kwargs, - ) + # Both Glm4vProcessor and Glm46VProcessor share markers; route to the same + # strategy. Independent try/except so either can be absent. + try: + from transformers.models.glm4v.processing_glm4v import Glm4vProcessor + + if isinstance(processor, Glm4vProcessor): + return Glm4vProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug("Glm4vProcessor import failed: %r", exc) + try: from transformers.models.glm46v.processing_glm46v import Glm46VProcessor if isinstance(processor, Glm46VProcessor): - return Glm4vProcessingStrategy( - **processing_kwargs, - ) - except ImportError: - pass + return Glm4vProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug("Glm46VProcessor import failed: %r", exc) if isinstance(processor, InternVLProcessor): - return InternVLProcessingStrategy( - **processing_kwargs, - ) + return InternVLProcessingStrategy(**processing_kwargs) - # llama3_2_vision, llama4, llava - # mistral_v7_tekken, pixtral, lfm2vl - return ProcessingStrategy( - **processing_kwargs, - ) + # Unregistered templates (llava, lfm2vl, mistral_v3_tekken, ...) use the + # base strategy; it warns once when train_on_inputs=False. + return ProcessingStrategy(**processing_kwargs) diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 6114a63e0..97ed71631 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -166,10 +166,10 @@ class SFTDataset(BaseModel): "description": "Roles to train on. The tokens from these roles will be considered for the loss." }, ) - train_on_eos: Literal["all", "turn", "last"] | None = Field( + train_on_eos: Literal["all", "turn", "last", "none"] | None = Field( default=None, json_schema_extra={ - "description": "Which EOS tokens to train on in the conversation. 
Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation" + "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens" }, ) roles: dict[str, list[str]] | None = Field( diff --git a/src/axolotl/utils/schemas/multimodal.py b/src/axolotl/utils/schemas/multimodal.py index a3449199f..01ad5e5a3 100644 --- a/src/axolotl/utils/schemas/multimodal.py +++ b/src/axolotl/utils/schemas/multimodal.py @@ -6,6 +6,57 @@ from PIL.Image import Resampling from pydantic import BaseModel, Field, field_validator +class RoleBoundarySpec(BaseModel): + """One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md.""" + + role: str = Field( + json_schema_extra={ + "description": ( + "Role name as it appears in cfg.roles_to_train (e.g. " + "'assistant', 'user', 'system', 'tool', 'ipython')." + ) + }, + ) + start: str = Field( + json_schema_extra={ + "description": ( + "Literal string that marks the start of this role's span in " + "the rendered chat template. Tokenized via " + "``tokenizer.encode(..., add_special_tokens=False)`` at " + "strategy init." + ) + }, + ) + end: str | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Literal string that marks the end of this role's span. " + "Set to ``eos_token`` to terminate at the tokenizer's EOS. " + "Leave unset / null to terminate at end-of-sequence." + ) + }, + ) + include_start: bool = Field( + default=False, + json_schema_extra={ + "description": ( + "Whether the start marker tokens contribute to loss on " + "trainable turns. Default False." + ) + }, + ) + include_end: bool = Field( + default=True, + json_schema_extra={ + "description": ( + "Whether the end marker tokens contribute to loss on " + "trainable turns (honoring cfg.train_on_eos). Default True." + ) + }, + ) + + class MultiModalConfig(BaseModel): """Multi-modal configuration subset""" @@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel): "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details." }, ) + role_boundaries: list[RoleBoundarySpec] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Opt-in override for the MM mask scanner's per-role boundary " + "markers. Non-empty list replaces built-ins wholesale; unset " + "or empty falls back to built-ins. See " + "docs/multimodal_assistant_mask.md." 
+ ) + }, + ) @field_validator("image_resize_algorithm", mode="before") @classmethod diff --git a/tests/test_processing_strategies.py b/tests/test_processing_strategies.py new file mode 100644 index 000000000..2d8f13fe5 --- /dev/null +++ b/tests/test_processing_strategies.py @@ -0,0 +1,1164 @@ +"""Tests for ``axolotl.processing_strategies`` using fake tokenizers (offline/CI-safe).""" + +import logging + +import pytest +import torch +from pydantic import ValidationError + +from axolotl.processing_strategies import ( + Gemma3nProcessingStrategy, + Gemma3ProcessingStrategy, + Gemma4ProcessingStrategy, + Llama3_2VisionProcessingStrategy, + Llama4ProcessingStrategy, + MistralV7TekkenProcessingStrategy, + PixtralProcessingStrategy, + ProcessingStrategy, + Qwen2VLProcessingStrategy, + Qwen3_5ProcessingStrategy, + RoleBoundary, + _apply_role_boundaries, + get_processing_strategy, +) + + +@pytest.fixture +def axolotl_caplog(caplog): + """caplog that also captures records from the ``axolotl`` logger. + + The axolotl logger sets ``propagate=False`` once ``configure_logging()`` is + called (which happens indirectly in many CI test paths), so the default + caplog handler installed on the root logger never sees these records. + Attaching ``caplog.handler`` to ``axolotl.processing_strategies`` directly + makes assertions reliable regardless of whether ``configure_logging`` has + already run on this worker. + """ + logger = logging.getLogger("axolotl.processing_strategies") + logger.addHandler(caplog.handler) + previous_level = logger.level + logger.setLevel(logging.DEBUG) + try: + yield caplog + finally: + logger.removeHandler(caplog.handler) + logger.setLevel(previous_level) + + +# --------------------------------------------------------------------------- # +# Generic fake tokenizer/processor scaffold +# --------------------------------------------------------------------------- # + + +class _Tokenizer: + """Minimal tokenizer stub; ``vocab`` maps marker strings to their id lists.""" + + def __init__( + self, + vocab: dict[str, list[int]], + pad_id: int = 0, + unk_id: int = 3, + eos_id: int | None = None, + ): + self.vocab = vocab + self._reverse = {} + for tok, ids in vocab.items(): + if len(ids) == 1: + self._reverse[ids[0]] = tok + self.pad_token_id = pad_id + self.unk_token_id = unk_id + if eos_id is not None: + self.eos_token_id = eos_id + + def encode(self, text, add_special_tokens=False): + # Unknown markers return [] so _encode_markers drops them silently. 
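+        # Built-in markers missing from ``vocab`` therefore encode to [],
+        # exercising the same skip path a real template without that role
+        # would take.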
+ return list(self.vocab.get(text, [])) + + def convert_tokens_to_ids(self, token): + v = self.vocab.get(token) + if v is None: + return self.unk_token_id + return v[0] if len(v) == 1 else self.unk_token_id + + +class _Processor: + def __init__(self, tokenizer: _Tokenizer): + self.tokenizer = tokenizer + + +# --------------------------------------------------------------------------- # +# Base scanner tests (train_on_inputs / roles_to_train / train_on_eos) +# --------------------------------------------------------------------------- # + + +def _scan(role_boundaries, seq, roles_to_train=("assistant",), train_on_eos="turn"): + labels = torch.tensor([seq]) + return _apply_role_boundaries( + labels, role_boundaries, set(roles_to_train), train_on_eos + ).tolist()[0] + + +def test_scanner_assistant_only_basic(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 8, 9, 5] + out = _scan(boundaries, seq) + assert out == [-100, -100, -100, -100, -100, -100, 8, 8, 9, -100] + + +def test_scanner_train_on_eos_none_excludes_end_marker(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 8, 8, 9] + out = _scan(boundaries, seq, train_on_eos="none") + assert out == [-100, -100, 8, 8, -100] + + +def test_scanner_train_on_eos_all_keeps_non_assistant_end_marker(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9] + out = _scan(boundaries, seq, train_on_eos="all") + assert out == [-100, -100, -100, 9, -100, -100, 8, 9] + + +def test_scanner_train_on_eos_all_with_non_trainable_include_end_false(): + """Non-trainable + include_end=False must not leak end marker on 'all'.""" + boundaries = [ + RoleBoundary( + role="user", + start_tokens=[50], + end_tokens=[51], + include_end=False, # shared with assistant-start + ), + RoleBoundary( + role="assistant", + start_tokens=[51], + end_tokens=[99], # eos + ), + ] + seq = [50, 7, 51, 8, 8, 99] # [INST] 7 [/INST] 8 8 EOS + out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="all") + # [/INST] at idx 2 must stay masked — user.include_end=False says so. + assert out == [-100, -100, -100, 8, 8, 99] + + +def test_scanner_roles_to_train_user_and_assistant(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9] + out = _scan(boundaries, seq, roles_to_train=("user", "assistant")) + # include_start defaults to False so role-start markers stay masked. 
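+    # Expected per turn: user-start [1, 3] masked, content 7 and end 9 kept;
+    # assistant-start [1, 2] masked, content 8 and end 9 kept.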
+ assert out == [-100, -100, 7, 9, -100, -100, 8, 9] + + +def test_scanner_truncated_assistant(): + """Missing end marker: span runs to end-of-sequence, end marker not emitted.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 8, 8, 8] + out = _scan(boundaries, seq) + assert out == [-100, -100, 8, 8, 8] + + +def test_scanner_longest_prefix_wins(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2, 4], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 4, 8, 9] + out = _scan(boundaries, seq) + assert out == [-100, -100, -100, 8, 9] + + +def test_scanner_no_boundaries_masks_everything(): + # Strategies short-circuit this in _mask_non_assistant; see test_base_strategy_warns_when_no_boundaries. + labels = torch.tensor([[1, 2, 3, 4]]) + out = _apply_role_boundaries(labels, [], {"assistant"}, "turn") + assert out.tolist() == [[-100, -100, -100, -100]] + + +def test_scanner_train_on_eos_last_only_final_trainable_turn(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 5, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq, train_on_eos="last") + # Only the second assistant turn's end marker (index 7) is kept. + assert out == [-100, -100, 5, -100, -100, -100, 6, 9] + + +def test_scanner_train_on_eos_last_no_trainable_turn_is_noop(): + boundaries = [ + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 5, 9, 1, 3, 6, 9] + out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="last") + assert out == [-100] * 8 + + +def test_strategy_rejects_unknown_train_on_eos(): + vocab = {"BOA": [50], "EOT": [60]} + with pytest.raises(ValueError, match="train_on_eos"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + train_on_eos="bogus", + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"} + ], + ) + + +def test_strategy_accepts_all_supported_train_on_eos_values(): + vocab = {"BOA": [50], "EOT": [60]} + for val in ("turn", "all", "none", "last"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + train_on_eos=val, + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"} + ], + ) + + +def test_empty_role_boundaries_override_falls_back_to_builtin(): + """Empty override must fall through to built-ins (opt-in semantics).""" + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + } + strat_empty = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[], + ) + strat_default = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + ) + # Empty override === no override: both strategies keep the built-in boundaries. 
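+    # Honoring [] literally would mask every label and zero out the loss,
+    # so the ctor treats a falsy override as unset.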
+ assert strat_empty.role_boundaries == strat_default.role_boundaries + assert len(strat_empty.role_boundaries) > 0 # sanity: built-ins are non-empty + + +def test_sft_dataset_schema_accepts_all_supported_train_on_eos_values(): + """SFTDataset.train_on_eos accepts every value the scanner honors.""" + from axolotl.utils.schemas.datasets import SFTDataset + + for val in ("all", "turn", "last", "none"): + ds = SFTDataset(path="dummy", type="chat_template", train_on_eos=val) + assert ds.train_on_eos == val + + with pytest.raises(ValidationError): + SFTDataset(path="dummy", type="chat_template", train_on_eos="bogus") + + +def test_strategy_init_logs_resolved_masking_config_builtin(axolotl_caplog): + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + } + with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"): + Qwen2VLProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + msgs = [r.getMessage() for r in axolotl_caplog.records] + assert any( + "ProcessingStrategy init" in m + and "Qwen2VLProcessingStrategy" in m + and "boundaries_source=built-in" in m + for m in msgs + ) + + +def test_strategy_init_logs_resolved_masking_config_override(axolotl_caplog): + vocab = {"BOA": [50, 51], "EOT": [60]} + with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + ], + ) + msgs = [r.getMessage() for r in axolotl_caplog.records] + # Resolved start/end ids must appear in the log so users can verify what + # was actually matched. + assert any( + "ProcessingStrategy init" in m + and "boundaries_source=override" in m + and "[50, 51]" in m + and "[60]" in m + for m in msgs + ) + + +def test_process_labels_no_warning_when_image_token_id_none(): + """image_token_id=None must not trigger a UserWarning from ``labels == None``.""" + import warnings + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}], + ) + assert strategy.image_token_id is None + with warnings.catch_warnings(): + warnings.simplefilter("error") + strategy.process_labels(torch.tensor([[1, 50, 2, 3, 60]])) + + +def test_roles_to_train_empty_list_masks_everything(): + """An explicit empty list is distinct from None and disables all roles.""" + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + roles_to_train=[], + role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}], + ) + assert strategy.roles_to_train == [] + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 6 + + +# --------------------------------------------------------------------------- # +# Qwen2VL / Qwen3.5 +# --------------------------------------------------------------------------- # + + +def _qwen_tokenizer(): + # ChatML-ish with image_pad=200, video_pad=201. 
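+    # A rendered turn "<|im_start|>assistant\n...<|im_end|>" maps to
+    # [101, 102, 103, <content ids>, 104] under this vocab.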
+    vocab = {
+        "<|im_start|>assistant\n": [101, 102, 103],
+        "<|im_start|>user\n": [101, 106, 103],
+        "<|im_start|>system\n": [101, 105, 103],
+        "<|im_end|>": [104],
+        "<|image_pad|>": [200],
+        "<|video_pad|>": [201],
+    }
+    return _Tokenizer(vocab, pad_id=0)
+
+
+def _make_qwen2vl():
+    tok = _qwen_tokenizer()
+    return Qwen2VLProcessingStrategy(_Processor(tok))
+
+
+def test_qwen2vl_masks_user_keeps_assistant_and_image_pad():
+    strategy = _make_qwen2vl()
+    seq = [
+        101,
+        105,
+        103,
+        77,
+        104,
+        101,
+        106,
+        103,
+        7,
+        104,
+        101,
+        102,
+        103,
+        200,
+        8,
+        104,
+    ]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    out = labels.tolist()[0]
+    assert out[:10] == [-100] * 10
+    assert out[10] == -100 and out[11] == -100 and out[12] == -100
+    assert out[13] == -100  # image_pad masked post-scan
+    assert out[14] == 8
+    assert out[15] == 104
+
+
+def test_qwen3_5_masks_video_pad_too():
+    tok = _qwen_tokenizer()
+    strategy = Qwen3_5ProcessingStrategy(_Processor(tok))
+    seq = [101, 102, 103, 201, 8, 104]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    assert labels.tolist()[0] == [-100, -100, -100, -100, 8, 104]
+
+
+def test_qwen2vl_train_on_inputs_true_keeps_everything():
+    tok = _qwen_tokenizer()
+    strategy = Qwen2VLProcessingStrategy(_Processor(tok), train_on_inputs=True)
+    seq = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    assert labels.tolist()[0] == seq
+
+
+# --------------------------------------------------------------------------- #
+# Gemma3 / Gemma3n
+# --------------------------------------------------------------------------- #
+
+
+def _gemma_tokenizer():
+    vocab = {
+        "<start_of_turn>model\n": [1, 2, 3],
+        "<start_of_turn>user\n": [1, 10, 3],
+        "<start_of_turn>system\n": [1, 11, 3],
+        "<end_of_turn>": [4],
+        "<start_of_image>": [50],  # boi_token for Gemma3
+    }
+    tok = _Tokenizer(vocab, pad_id=0)
+    # boi_token is a direct tokenizer attribute on real Gemma3.
+    tok.boi_token = "<start_of_image>"
+    return tok
+
+
+def test_gemma3_scanner_plus_soft_image_token():
+    strategy = Gemma3ProcessingStrategy(_Processor(_gemma_tokenizer()))
+    seq = [1, 10, 3, 7, 4, 1, 2, 3, 50, 8, 262144, 4]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    # boi(50) and soft-image-token(262144) masked post-scan.
+    assert labels.tolist()[0] == [
+        -100,
+        -100,
+        -100,
+        -100,
+        -100,
+        -100,
+        -100,
+        -100,
+        -100,
+        8,
+        -100,
+        4,
+    ]
+
+
+def test_gemma3n_masks_image_and_audio_attrs():
+    tok = _gemma_tokenizer()
+    # Gemma3n exposes these as integer attrs on the tokenizer.
+    tok.image_token_id = 70
+    tok.audio_token_id = 71
+    tok.boi_token_id = 72
+    tok.eoi_token_id = 73
+    strategy = Gemma3nProcessingStrategy(_Processor(tok))
+    seq = [1, 2, 3, 70, 71, 72, 73, 9, 4]
+    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
+    assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 4]
+
+
+# --------------------------------------------------------------------------- #
+# Gemma 4
+# --------------------------------------------------------------------------- #
+
+
+class _FakeGemma4Tokenizer(_Tokenizer):
+    """Mirrors google/gemma-4-E2B-it token layout. Gemma4 role-start markers
+    include the trailing newline so the boundary matches the jinja template."""
+
+    VOCAB = {
+        "<|turn>model\n": [105, 4368, 108],
+        "<|turn>user\n": [105, 7777, 108],
+        "<|turn>system\n": [105, 8888, 108],
+        "</|turn>": [106],
+        "<|image|>": [258880],
+        "<|video|>": [258884],
+        "<|audio|>": [258881],
+        "<|image>": [255999],
+        "</|image>": [258882],
+        "<|audio>": [256000],
+        "</|audio>": [258883],
+    }
+
+    def __init__(self):
+        # Pass a fresh dict so per-instance mutations (should any future
+        # code path introduce them) cannot leak across tests via the
+        # shared class-level VOCAB.
+        super().__init__(
+            {token: list(ids) for token, ids in self.VOCAB.items()},
+            pad_id=0,
+            unk_id=3,
+        )
+
+
+class _FakeGemma4Processor:
+    def __init__(self):
+        self.tokenizer = _FakeGemma4Tokenizer()
+        self.tokenizer.image_token_id = self.tokenizer.vocab["<|image|>"][0]
+        self.tokenizer.audio_token_id = self.tokenizer.vocab["<|audio|>"][0]
+        self.image_token = "<|image|>"
+        self.image_token_id = self.tokenizer.vocab["<|image|>"][0]
+        self.boi_token = "<|image>"
+        self.eoi_token = "</|image>"
+        self.video_token = "<|video|>"
+        self.video_token_id = self.tokenizer.vocab["<|video|>"][0]
+        self.audio_token = "<|audio|>"
+        self.audio_token_id = self.tokenizer.vocab["<|audio|>"][0]
+        self.boa_token = "<|audio>"
+        self.eoa_token = "</|audio>"
+
+
+def test_gemma4_masks_everything_outside_assistant_span():
+    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
+    V = strategy.processor.tokenizer.vocab
+    user_start = V["<|turn>user\n"]
+    model_start = V["<|turn>model\n"]
+    turn_end = V["</|turn>"][0]
+    seq = [
+        0,
+        *user_start,
+        4444,
+        turn_end,
+        *model_start,
+        5555,
+        turn_end,
+        9999,
+    ]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    expected = [-100] * (1 + len(user_start) + 1 + 1 + len(model_start)) + [
+        5555,
+        turn_end,
+        -100,
+    ]
+    assert labels.tolist()[0] == expected
+
+
+def test_gemma4_masks_media_tokens_inside_assistant_span():
+    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
+    V = strategy.processor.tokenizer.vocab
+    model_start = V["<|turn>model\n"]
+    media = [
+        V["<|image|>"][0],
+        V["<|video|>"][0],
+        V["<|audio|>"][0],
+        V["<|image>"][0],
+        V["</|image>"][0],
+        V["<|audio>"][0],
+        V["</|audio>"][0],
+    ]
+    turn_end = V["</|turn>"][0]
+    seq = [*model_start, *media, 9999, turn_end]
+    labels = strategy.process_labels(torch.tensor([seq]))
+    expected = [-100] * (len(model_start) + len(media)) + [9999, turn_end]
+    assert labels.tolist()[0] == expected
+
+
+def test_gemma4_multiple_assistant_turns():
+    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
+    V = strategy.processor.tokenizer.vocab
+    turn_end = V["</|turn>"][0]
+
+    def user_turn(x):
+        return [*V["<|turn>user\n"], x, turn_end]
+
+    def model_turn(x):
+        return [*V["<|turn>model\n"], x, turn_end]
+
+    seq = user_turn(1111) + model_turn(2222) + user_turn(3333) + model_turn(4444)
+    labels = strategy.process_labels(torch.tensor([seq]))
+    kept = [t for t in labels.tolist()[0] if t != -100]
+    assert kept == [2222, turn_end, 4444, turn_end]
+
+
+# --------------------------------------------------------------------------- #
+# Llama 3.2 Vision / Llama 4
+# --------------------------------------------------------------------------- #
+
+
+def test_llama3_2_vision_assistant_masking():
+    vocab = {
+        "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5],
+        "<|start_header_id|>user<|end_header_id|>\n\n": [1, 2, 6, 4, 5],
+        "<|start_header_id|>system<|end_header_id|>\n\n": [1, 2, 7, 4, 5],
+        "<|start_header_id|>tool<|end_header_id|>\n\n": [1, 2, 8, 4, 5],
+
"<|start_header_id|>ipython<|end_header_id|>\n\n": [1, 2, 9, 4, 5], + "<|eot_id|>": [10], + } + strategy = Llama3_2VisionProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + seq = [1, 2, 6, 4, 5, 11, 10, 1, 2, 3, 4, 5, 12, 10] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 12 + [12, 10] + + +def test_llama4_assistant_masking(): + vocab = { + "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23], + "<|header_start|>user<|header_end|>\n\n": [20, 21, 24, 23], + "<|header_start|>system<|header_end|>\n\n": [20, 21, 25, 23], + "<|header_start|>tool<|header_end|>\n\n": [20, 21, 26, 23], + "<|header_start|>ipython<|header_end|>\n\n": [20, 21, 27, 23], + "<|eot|>": [30], + } + strategy = Llama4ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + seq = [20, 21, 24, 23, 100, 30, 20, 21, 22, 23, 200, 30] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 10 + [200, 30] + + +# --------------------------------------------------------------------------- # +# Pixtral / Mistral v7 Tekken (eos-terminated assistant) +# --------------------------------------------------------------------------- # + + +def test_pixtral_assistant_terminates_at_eos(): + # [/INST] is both user-end and assistant-start. Scanner backs up when + # user.include_end=False so the next iteration picks [/INST] up as + # assistant-start (Pixtral-specific handling in _build_role_boundaries). + vocab = { + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = PixtralProcessingStrategy(_Processor(tok)) + seq = [50, 7, 51, 8, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # Full-sequence expectation: user span masked; assistant content + eos kept. + assert out == [-100, -100, -100, 8, 8, 99] + + +def test_mistral_v7_tekken_system_user_assistant(): + vocab = { + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = MistralV7TekkenProcessingStrategy(_Processor(tok)) + seq = [40, 5, 41, 50, 7, 51, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # Full-sequence expectation: system + user spans masked; assistant kept. + assert out == [-100, -100, -100, -100, -100, -100, 8, 99] + + +def test_pixtral_train_on_eos_all_respects_user_include_end_false(): + """Pixtral [/INST] (user-end include_end=False) stays masked on 'all'.""" + vocab = {"[INST]": [50], "[/INST]": [51]} + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = PixtralProcessingStrategy(_Processor(tok), train_on_eos="all") + seq = [50, 7, 51, 8, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # [/INST] at idx 2 must stay masked — user.include_end=False says so. + # Assistant content (8, 8) + EOS (99) are unmasked as normal. 
+ assert out == [-100, -100, -100, 8, 8, 99] + + +def test_mistral_v7_tekken_train_on_eos_all_respects_user_include_end_false(): + """System end (include_end=True) unmasked on 'all'; [/INST] stays masked.""" + vocab = { + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = MistralV7TekkenProcessingStrategy(_Processor(tok), train_on_eos="all") + seq = [40, 5, 41, 50, 7, 51, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # system content masked, [/SYSTEM_PROMPT]=41 kept (include_end=True + all); + # user + [/INST]=51 masked (include_end=False); assistant 8 + eos 99 kept. + assert out == [-100, -100, 41, -100, -100, -100, 8, 99] + + +# --------------------------------------------------------------------------- # +# Dispatcher routing +# --------------------------------------------------------------------------- # + + +def _dispatch(processor, chat_template_type): + return get_processing_strategy( + processor=processor, + chat_template=None, + chat_template_type=chat_template_type, + ) + + +def test_dispatch_qwen2_vl(): + s = _dispatch(_Processor(_qwen_tokenizer()), "qwen2_vl") + assert isinstance(s, Qwen2VLProcessingStrategy) + + +def test_dispatch_qwen3_5(): + s = _dispatch(_Processor(_qwen_tokenizer()), "qwen3_5") + assert isinstance(s, Qwen3_5ProcessingStrategy) + + +def test_dispatch_gemma3(): + s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3") + assert isinstance(s, Gemma3ProcessingStrategy) + + +def test_dispatch_gemma3n(): + s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3n") + assert isinstance(s, Gemma3nProcessingStrategy) + + +def test_dispatch_gemma4(): + s = _dispatch(_FakeGemma4Processor(), "gemma4") + assert isinstance(s, Gemma4ProcessingStrategy) + + +def test_dispatch_llama3_2_vision(): + vocab = { + "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5], + "<|eot_id|>": [10], + } + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama3_2_vision") + assert isinstance(s, Llama3_2VisionProcessingStrategy) + + +def test_dispatch_llama4(): + vocab = { + "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23], + "<|eot|>": [30], + } + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama4") + assert isinstance(s, Llama4ProcessingStrategy) + + +def test_dispatch_pixtral(): + vocab = {"[INST]": [50], "[/INST]": [51]} + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "pixtral") + assert isinstance(s, PixtralProcessingStrategy) + + +def test_dispatch_mistral_v7_tekken(): + vocab = { + "[INST]": [50], + "[/INST]": [51], + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + } + s = _dispatch( + _Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "mistral_v7_tekken" + ) + assert isinstance(s, MistralV7TekkenProcessingStrategy) + + +def test_dispatch_unknown_falls_back_to_base(): + vocab = {"dummy": [1]} + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llava") + assert type(s) is ProcessingStrategy + + +def _glm_vision_processor(cls_path): + """Spec'd MagicMock so isinstance(mock, cls) passes without real HF files.""" + from importlib import import_module + from unittest.mock import MagicMock + + mod_name, cls_name = cls_path.rsplit(".", 1) + cls = getattr(import_module(mod_name), cls_name) + + vocab = { + "<|image|>": [200], + "<|begin_of_image|>": [201], + "<|end_of_image|>": [202], + "<|video|>": [210], + "<|begin_of_video|>": [211], + "<|end_of_video|>": [212], + } + tok = 
_Tokenizer(vocab, pad_id=0) + proc = MagicMock(spec=cls) + proc.tokenizer = tok + # Drop processor.image_token so base class skips its probe. + del proc.image_token + return proc + + +def test_dispatch_glm4v_via_Glm4vProcessor(): + """Glm4vProcessor (GLM-4V) routes to Glm4vProcessingStrategy.""" + pytest.importorskip("transformers.models.glm4v.processing_glm4v") + from axolotl.processing_strategies import Glm4vProcessingStrategy + + proc = _glm_vision_processor( + "transformers.models.glm4v.processing_glm4v.Glm4vProcessor" + ) + s = _dispatch(proc, None) + assert isinstance(s, Glm4vProcessingStrategy) + + +def test_dispatch_glm4v_via_Glm46VProcessor(): + """Glm46VProcessor (GLM-4.6V) also routes to Glm4vProcessingStrategy.""" + pytest.importorskip("transformers.models.glm46v.processing_glm46v") + from axolotl.processing_strategies import Glm4vProcessingStrategy + + proc = _glm_vision_processor( + "transformers.models.glm46v.processing_glm46v.Glm46VProcessor" + ) + s = _dispatch(proc, None) + assert isinstance(s, Glm4vProcessingStrategy) + + +# --------------------------------------------------------------------------- # +# Config-based role-boundary override +# --------------------------------------------------------------------------- # + + +def test_role_boundaries_override_replaces_built_in(): + """Override swaps the built-in boundaries wholesale, not additively.""" + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + ">>>A": [200, 201], + ">>>U": [200, 202], + "<<<": [210], + "<|image_pad|>": [250], + } + strategy = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": ">>>A", "end": "<<<"}, + {"role": "user", "start": ">>>U", "end": "<<<"}, + ], + ) + seq = [ + 101, + 106, + 103, + 7, + 104, + 200, + 201, + 9, + 9, + 210, + ] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 9, 210] + + +def test_role_boundaries_override_enables_unverified_strategy(): + """Override lets users opt in to role masking on strategies that default opt out.""" + vocab = { + "BOA": [50, 51], + "EOT": [60], + } + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + ], + ) + seq = [1, 2, 3, 50, 51, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, -100, 7, 8, 60, -100] + + +def test_role_boundaries_override_eos_token_sentinel(): + vocab = {"BOA": [50]} + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = ProcessingStrategy( + _Processor(tok), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "eos_token"}, + ], + ) + seq = [1, 50, 7, 7, 99, 2] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, 7, 7, 99, -100] + + +def test_role_boundaries_override_end_null_runs_to_sequence_end(): + vocab = {"BOA": [50]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": None}, + ], + ) + seq = [1, 2, 50, 7, 8, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, 7, 8, 9] + + +def test_role_boundaries_override_rejects_bad_spec(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="must have both 
'role' and 'start'"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[{"role": "assistant"}], + ) + + +def test_role_boundaries_override_rejects_unencodable_start(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="tokenizes to an empty sequence"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "MISSING", "end": None} + ], + ) + + +def test_role_boundaries_override_rejects_unencodable_end(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="tokenizes to an empty sequence"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "MISSING"} + ], + ) + + +def test_role_boundaries_override_accepts_pydantic_models(): + # cfg.role_boundaries arrives as RoleBoundarySpec after pydantic parsing. + from axolotl.utils.schemas.multimodal import RoleBoundarySpec + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + RoleBoundarySpec(role="assistant", start="BOA", end="EOT") + ], + ) + assert len(strategy.role_boundaries) == 1 + assert strategy.role_boundaries[0].role == "assistant" + assert strategy.role_boundaries[0].start_tokens == [50] + assert strategy.role_boundaries[0].end_tokens == [60] + + +def test_base_strategy_warns_when_no_boundaries(axolotl_caplog): + """No boundaries + train_on_inputs=False: one-shot warning, labels unchanged.""" + import axolotl.processing_strategies as mod + + mod._ROLE_MASK_WARNED.discard("ProcessingStrategy") + + vocab = {"dummy": [1]} + s = ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + + with axolotl_caplog.at_level( + logging.WARNING, logger="axolotl.processing_strategies" + ): + labels = s.process_labels(torch.tensor([[1, 2, 3]])) + assert labels.tolist() == [[1, 2, 3]] + assert any("role boundaries" in rec.message for rec in axolotl_caplog.records) + + +# --------------------------------------------------------------------------- # +# Additional edge-case coverage +# --------------------------------------------------------------------------- # + + +def test_scanner_batch_size_greater_than_one(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + labels = torch.tensor( + [ + [1, 3, 7, 9, 1, 2, 8, 9], + [1, 2, 5, 5, 9, 0, 0, 0], + ] + ) + out = _apply_role_boundaries(labels, boundaries, {"assistant"}, "turn").tolist() + assert out[0] == [-100, -100, -100, -100, -100, -100, 8, 9] + assert out[1] == [-100, -100, 5, 5, 9, -100, -100, -100] + + +def test_scanner_adjacent_trainable_turns(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 5, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq) + assert out == [-100, -100, 5, 9, -100, -100, 6, 9] + + +def test_scanner_train_on_eos_none_multi_turn(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9, 1, 3, 7, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq, train_on_eos="none") + assert out == [ + -100, + -100, + -100, + -100, + -100, + -100, + 8, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 6, + -100, + ] + + +def test_scanner_train_on_eos_all_with_user_turn_no_end_marker(): + 
"""Unclosed non-trainable span with train_on_eos='all': nothing included, no crash.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 7, 7] + out = _scan(boundaries, seq, train_on_eos="all") + assert out == [-100, -100, -100, -100, -100] + + +def test_scanner_include_start_true_via_override(): + vocab = {"BOA": [50, 51], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + { + "role": "assistant", + "start": "BOA", + "end": "EOT", + "include_start": True, + }, + ], + ) + seq = [1, 50, 51, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, 50, 51, 7, 8, 60, -100] + + +def test_scanner_include_end_false_via_override(): + """include_end=False drops end marker even with train_on_eos='turn'.""" + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + { + "role": "assistant", + "start": "BOA", + "end": "EOT", + "include_end": False, + }, + ], + ) + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, 7, 8, -100, -100] + + +def test_scanner_empty_start_tokens_is_defensive_noop(): + """Defensive: empty start_tokens matches nothing; everything masked.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[], end_tokens=[9]), + ] + seq = [1, 2, 3, 4, 9] + out = _scan(boundaries, seq) + assert out == [-100] * 5 + + +def test_process_labels_masks_pad_inside_assistant_span(): + """Pad inside a trainable span is still masked post-scan.""" + strategy = _make_qwen2vl() + seq = [101, 102, 103, 8, 0, 8, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, 8, -100, 8, 104] + + +def test_process_labels_all_pad_sequence_does_not_crash(): + strategy = _make_qwen2vl() + seq = [0, 0, 0, 0] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100] + + +def test_qwen2vl_multiple_consecutive_assistant_turns(): + strategy = _make_qwen2vl() + seq = [101, 102, 103, 8, 104, 101, 102, 103, 9, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [ + -100, + -100, + -100, + 8, + 104, + -100, + -100, + -100, + 9, + 104, + ] + + +def test_qwen2vl_batch_of_two_rows(): + strategy = _make_qwen2vl() + row_a = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104] + row_b = [101, 102, 103, 9, 104, 0, 0, 0, 0, 0] + out = strategy.process_labels(torch.tensor([row_a, row_b])).tolist() + assert out[0] == [-100, -100, -100, -100, -100, -100, -100, -100, 8, 104] + assert out[1] == [-100, -100, -100, 9, 104, -100, -100, -100, -100, -100] + + +def test_qwen3_5_train_on_inputs_true_still_masks_video_pad(): + """train_on_inputs=True skips role masking but media tokens are still masked.""" + tok = _qwen_tokenizer() + strategy = Qwen3_5ProcessingStrategy(_Processor(tok), train_on_inputs=True) + seq = [101, 106, 103, 201, 7, 104, 101, 102, 103, 201, 8, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + expected = list(seq) + expected[3] = -100 + expected[9] = -100 + assert out == expected + + +def test_role_boundaries_override_role_not_in_roles_to_train(): + """Override covering only a non-trainable role masks everything.""" + vocab = {"BOU": [50], "EOT": [60]} + strategy = ProcessingStrategy( 
+ _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "user", "start": "BOU", "end": "EOT"}, + ], + ) + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 6 + + +def test_role_boundaries_override_include_start_flag_round_trips(): + from axolotl.utils.schemas.multimodal import RoleBoundarySpec + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + RoleBoundarySpec( + role="assistant", start="BOA", end="EOT", include_start=True + ), + ], + ) + assert len(strategy.role_boundaries) == 1 + assert strategy.role_boundaries[0].include_start is True + assert strategy.role_boundaries[0].include_end is True + + +def test_multimodal_config_parses_dict_role_boundaries_to_specs(): + from axolotl.utils.schemas.multimodal import ( + MultiModalConfig, + RoleBoundarySpec, + ) + + cfg = MultiModalConfig( + role_boundaries=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + {"role": "user", "start": "BOU", "end": "EOT"}, + ] + ) + assert cfg.role_boundaries is not None + assert len(cfg.role_boundaries) == 2 + assert all(isinstance(rb, RoleBoundarySpec) for rb in cfg.role_boundaries) + + vocab = {"BOA": [50], "BOU": [51], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=cfg.role_boundaries, + ) + seq = [51, 7, 60, 50, 8, 60] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, 8, 60]
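+
+
+def test_role_boundaries_override_yaml_shape_sketch():
+    """Documentation aid, not a new behavior: shows how a YAML
+    ``cfg.role_boundaries`` block maps onto the override dicts used above.
+    Marker strings and ids are invented for this fake tokenizer.
+
+    YAML equivalent::
+
+        role_boundaries:
+          - role: assistant
+            start: "<|a|>"
+            end: "<|end|>"
+          - role: user
+            start: "<|u|>"
+            end: "<|end|>"
+    """
+    vocab = {"<|a|>": [50], "<|u|>": [51], "<|end|>": [60]}
+    strategy = ProcessingStrategy(
+        _Processor(_Tokenizer(vocab, pad_id=0)),
+        role_boundaries_override=[
+            {"role": "assistant", "start": "<|a|>", "end": "<|end|>"},
+            {"role": "user", "start": "<|u|>", "end": "<|end|>"},
+        ],
+    )
+    seq = [51, 7, 60, 50, 8, 60]
+    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
+    # With the strategy defaults (assistant-only, train_on_eos="turn"), only
+    # assistant content (8) and its end marker (60) stay unmasked.
+    assert out == [-100, -100, -100, -100, 8, 60]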