diff --git a/docs/multimodal_assistant_mask.md b/docs/multimodal_assistant_mask.md new file mode 100644 index 000000000..339ab420f --- /dev/null +++ b/docs/multimodal_assistant_mask.md @@ -0,0 +1,84 @@ +# Multimodal assistant-only loss masking + +## Correct placement + +```yaml +# Top-level: only train_on_inputs lives here. +train_on_inputs: false + +datasets: + - path: data/train.jsonl + type: chat_template + roles_to_train: # per-dataset — this is what the MM scanner reads + - assistant + train_on_eos: turn # per-dataset — same + +test_datasets: + - path: data/val.jsonl + type: chat_template + split: train + roles_to_train: + - assistant + train_on_eos: turn +``` + +## How to verify at runtime + +`build_collator` logs the resolved knobs at INFO: + +```text +MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none +``` + +If `roles_to_train` logs as `None`, the YAML knobs are not reaching the +scanner — check that they are under `datasets[0]`, not at the root. + +Each verified strategy additionally logs its resolved boundary token ids at +strategy init (e.g. `<|turn>model` → `[105, 4368]`, `` → `[106]` for +Gemma 4). If a strategy emits the "has no built-in role boundaries ... only +pad and media tokens are masked" one-shot warning instead, it is on the +fallback path — declare per-role markers in YAML via `cfg.role_boundaries` +(below) to activate masking. The strategies currently on this path are +listed in the audit table above under `fallback + warn`. + +## Config-based override: `cfg.role_boundaries` + +For the "unverified" strategies above, or for custom chat templates that +don't match a built-in strategy's markers, users can declare role boundaries +directly in YAML without subclassing: + +```yaml +role_boundaries: + - role: assistant + start: "<|turn>model" + end: "" + - role: user + start: "<|turn>user" + end: "" + # Optional keys: + # include_start: false # default False + # include_end: true # default True, respects cfg.train_on_eos + # end: eos_token # sentinel: resolves to tokenizer.eos_token_id + # end: null # span runs to end of sequence +``` + +Semantics: + +- `start` and `end` are literal strings; axolotl encodes them at strategy + init via `tokenizer.encode(..., add_special_tokens=False)` and logs the + resolved token-id sequences at INFO level. +- The special value `end: eos_token` is the portable way to express + "Pixtral-style assistant turns end at EOS" without hard-coding an id. +- `role_boundaries` is an **opt-in override**. A non-empty list **replaces** + the strategy's built-in declarations wholesale (partial overlays are + intentionally unsupported — they're hard to reason about at review time). + Leaving the field unset *or* setting it to an empty list (`[]`) both mean + "use the strategy's built-ins." Writing `role_boundaries: []` is almost + always a typo or leftover — honoring it literally would produce all-masked + labels and zero gradient, so it is treated the same as unset. +- `cfg.roles_to_train` still governs which declared roles contribute to + loss. You can declare `user` and `assistant` boundaries and set + `roles_to_train: ["assistant"]` to have the scanner correctly identify + user spans as masking boundaries without training on their content. +- Invalid specs fail loudly at strategy init (missing `role`/`start`, + unencodable markers), not silently at loss-compute time. diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 15624173d..aa1678523 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -515,12 +515,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator + # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg. + # NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets; + # heterogeneous per-dataset overrides are not supported in the MM path today. + ds_entries = self.cfg.datasets or [] + ds_cfg = ds_entries[0] if ds_entries else None + + def _ds_get(cfg_obj, key): + # Handle DictDefault / dict / pydantic uniformly: + # dict-style .get first, then attribute access. + if cfg_obj is None: + return None + if hasattr(cfg_obj, "get"): + try: + return cfg_obj.get(key) + except (AttributeError, KeyError, TypeError): + pass + return getattr(cfg_obj, key, None) + + roles_to_train = _ds_get(ds_cfg, "roles_to_train") + train_on_eos = _ds_get(ds_cfg, "train_on_eos") + + # cfg.role_boundaries replaces the strategy's built-in markers. + role_boundaries_override = None + if self.cfg.role_boundaries: + role_boundaries_override = list(self.cfg.role_boundaries) + + # build() calls build_collator twice (eval + train); log once. + if not is_eval: + LOG.info( + "MM collator: train_on_inputs=%s roles_to_train=%s " + "train_on_eos=%s role_boundaries_override=%s", + bool(self.cfg.train_on_inputs), + roles_to_train, + train_on_eos, + "set" if role_boundaries_override else "none", + ) + kwargs["processing_strategy"] = get_processing_strategy( self.processor, training_args.chat_template, self.cfg.chat_template, image_size=training_args.image_size, image_resize_algorithm=training_args.image_resize_algorithm, + train_on_inputs=bool(self.cfg.train_on_inputs), + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, ) elif self.cfg.batch_flattening: collator = DataCollatorWithFlattening diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index cb1f9d984..217bc765b 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -1,6 +1,7 @@ """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" from copy import deepcopy +from dataclasses import dataclass, field from typing import Optional from PIL import Image, ImageOps @@ -17,9 +18,35 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +# One-shot warning dedupe so opt-out subclasses don't spam per-batch. +_ROLE_MASK_WARNED: set[str] = set() + +# Supported values for ``train_on_eos`` — mirrors the text-only +# ChatTemplateStrategy (``turn`` = trainable turn ends only, ``all`` = every +# turn end, ``none`` = never, ``last`` = only the final trainable turn end). +_VALID_TRAIN_ON_EOS = ("turn", "all", "none", "last") + + +@dataclass(frozen=True) +class RoleBoundary: + """One role's token-level span markers for the masking scanner. + + Empty ``end_tokens`` means end-of-sequence terminates the span. + """ + + role: str + start_tokens: list[int] + end_tokens: list[int] = field(default_factory=list) + include_start: bool = False + include_end: bool = True + class ProcessingStrategy: - """Base Processing Strategy class""" + """Base Processing Strategy class. + + Subclasses opt in to role masking by overriding ``_build_role_boundaries``; + otherwise only pad + media tokens are masked (legacy behavior, one-shot warned). + """ def __init__( self, @@ -27,6 +54,10 @@ class ProcessingStrategy: chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): self.processor = processor self.chat_template = chat_template @@ -38,54 +69,97 @@ class ProcessingStrategy: image_resize_algorithm or Image.Resampling.BILINEAR ) + # Defaults mirror the text-only ChatTemplateStrategy. An explicit + # empty list is honored as "no trainable roles" (masks everything); + # only ``None`` falls back to the default of assistant-only. + self.train_on_inputs = bool(train_on_inputs) + self.roles_to_train = ( + list(roles_to_train) if roles_to_train is not None else ["assistant"] + ) + self.train_on_eos = train_on_eos if train_on_eos is not None else "turn" + if self.train_on_eos not in _VALID_TRAIN_ON_EOS: + raise ValueError( + f"train_on_eos={self.train_on_eos!r} is not one of " + f"{_VALID_TRAIN_ON_EOS}." + ) + if hasattr(processor, "image_token"): self.image_token = processor.image_token self.image_token_id = processor.tokenizer.convert_tokens_to_ids( self.image_token ) + built_in = self._build_role_boundaries() + + # Truthiness check: empty list == unset (opt-in escape hatch), so + # `role_boundaries: []` in YAML falls through to built-ins instead of + # producing all-masked labels. + if role_boundaries_override: + overridden = _resolve_role_boundary_override( + role_boundaries_override, self.processor.tokenizer + ) + LOG.info( + "%s: overriding built-in role boundaries (%d decls) " + "with cfg.role_boundaries (%d decls).", + type(self).__name__, + len(built_in), + len(overridden), + ) + self.role_boundaries: list[RoleBoundary] = overridden + source = "override" + else: + self.role_boundaries = built_in + source = "built-in" + + # Single-line, grep-friendly summary of the resolved masking config so + # "why isn't masking firing?" is visible in training logs. For + # overrides we include the fully resolved (role, start_ids, end_ids) + # tuples; for built-ins we log a count (subclasses vary and logging + # every id sequence would be noisy on, e.g., Llama3 with five roles). + boundaries_repr: str | list[tuple[str, list[int], list[int]]] + if source == "override": + boundaries_repr = [ + (b.role, b.start_tokens, b.end_tokens) for b in self.role_boundaries + ] + else: + boundaries_repr = f"{len(self.role_boundaries)} built-in" + LOG.info( + "ProcessingStrategy init: class=%s train_on_inputs=%s " + "roles_to_train=%s train_on_eos=%s boundaries_source=%s " + "boundaries=%s", + type(self).__name__, + self.train_on_inputs, + self.roles_to_train, + self.train_on_eos, + source, + boundaries_repr, + ) + + def _build_role_boundaries(self) -> list[RoleBoundary]: + """Subclasses declare role boundaries here; [] opts out of role masking.""" + return [] + def __call__(self, examples: list[dict]) -> list[dict]: - """ - Preprocess conversation examples to ensure consistent format. - Converts different conversation formats to OpenAI format with 'messages'. - Supports two formats: - 1. OpenAI format with 'messages' - 2. Legacy format with 'conversations' - - Args: - examples: list of conversation dictionaries - - Returns: - list of dicts in OpenAI format with 'messages' key - - Raises: - ValueError: If the conversation format is not supported - """ + """Normalize examples to OpenAI ``messages`` format (accepts legacy ``conversations``).""" role_mapping = { "human": "user", "gpt": "assistant", } def normalize_role(role: str) -> str: - """Normalize role names to OpenAI format. Default to original role if not found.""" return role_mapping.get(role, role) def convert_legacy_format(example: dict) -> dict: - """Convert legacy 'conversations' format to OpenAI 'messages' format.""" messages = [ {"role": normalize_role(convo["from"]), "content": convo["value"]} for convo in example["conversations"] ] - - # Create new dict without 'conversations' key result = deepcopy(example) result.pop("conversations") result["messages"] = messages return result def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: - """Convert regular messages format to Messages format with content type""" - new_messages = [] for message in messages: if isinstance(message["content"], str): @@ -119,21 +193,27 @@ class ProcessingStrategy: "Only `messages` and `conversations` message keys are currently supported." ) - processed_example = None - if ( - "messages" in example and example["messages"] is not None - ): # OpenAI format - processed_example = example - else: # Legacy format + if "messages" in example and example["messages"] is not None: + # Deepcopy for symmetry with convert_legacy_format (which + # deepcopies internally) so downstream mutations of + # processed_example don't leak back to the caller's input. + processed_example = deepcopy(example) + elif "conversations" in example: processed_example = convert_legacy_format(example) + else: + # `messages` is present but None, and no `conversations` + # fallback exists — convert_legacy_format would KeyError on + # ["conversations"]. Surface a clear validation error instead. + raise ValueError( + "`messages` is present but None; provide non-null " + "`messages` or a `conversations` field." + ) - # convert regular messages format to Messages format with content type - # for compatibility with apply_chat_template + # Required for apply_chat_template compatibility. processed_example["messages"] = convert_messages_to_multimedia_messages( processed_example["messages"] ) - # find the image key if it exists possible_image_keys = ["images", "image"] image_key = None for key in possible_image_keys: @@ -141,11 +221,8 @@ class ProcessingStrategy: image_key = key break - # if the image key exists, add the image to the first user message if image_key is not None and processed_example[image_key] is not None: - # TODO: check if it's normal to be single image only for common datasets - # From observation, it's usually a list of single image but some datasets may have several columns for images - # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages + # TODO: support multi-image samples; for now we take the first. if len(processed_example[image_key]) > 1: LOG.warning( f"Found {len(processed_example[image_key])} images in a sample. Using the first one." @@ -155,7 +232,6 @@ class ProcessingStrategy: image_value = processed_example[image_key][0] - # Handle image loading (Image, url, path, base64) image_value = load_image(image_value) if self.image_size is not None: @@ -168,11 +244,8 @@ class ProcessingStrategy: self.image_size, self.image_resize_algorithm ) else: - # Set the padding value; here we use black (0, 0, 0) for RGB images + # Int image_size: preserve aspect ratio then pad to square (black) to avoid distortion. padding_color = (0, 0, 0) - - # When image_size is an int (square target), preserve aspect ratio then pad - # This is to prevent aspect ratio distortion when resizing to square image_value = ImageOps.pad( image_value, (self.image_size, self.image_size), @@ -180,8 +253,6 @@ class ProcessingStrategy: color=padding_color, ) - # Look for any image type in the first message - # some dataset have an {type: "image"} in the first message msg_ind_to_add = None ind_to_add = None first_user_idx = None @@ -192,7 +263,7 @@ class ProcessingStrategy: for i, content in enumerate( processed_example["messages"][msg_idx]["content"] ): - # Usually datasets created with image columns, don't have it in the messages itself + # Column-image datasets often leave a bare {type: "image"} placeholder. if content["type"] == "image" and all( k not in content for k in ["image", "url", "path", "base64"] ): @@ -200,13 +271,11 @@ class ProcessingStrategy: ind_to_add = i break - # If an image type is found, add the image to that index if ind_to_add is not None and msg_ind_to_add is not None: processed_example["messages"][msg_ind_to_add]["content"][ ind_to_add ]["image"] = image_value else: - # if no image type is found, add it to end of the first user message if first_user_idx is None: first_user_idx = 0 processed_example["messages"][first_user_idx]["content"].append( @@ -221,28 +290,224 @@ class ProcessingStrategy: return processed_examples def _mask_non_assistant(self, labels: Tensor) -> Tensor: - """ - Mask non assistant regions to -100. - To be implemented per subclass. - """ - return labels + """Mask non-trainable role regions to -100 using ``self.role_boundaries``.""" + if self.train_on_inputs: + return labels + + # Legacy no-op for boundary-less strategies; warn once so the miss shows up in logs. + if not self.role_boundaries: + key = type(self).__name__ + if key not in _ROLE_MASK_WARNED: + _ROLE_MASK_WARNED.add(key) + LOG.warning( + "%s has no built-in role boundaries; " + "cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos " + "will NOT restrict loss to assistant tokens for this " + "multimodal model — only pad and media tokens are masked, " + "every other token (system, user, assistant) contributes " + "to loss. To enable assistant-only masking, declare " + "per-role markers in YAML via cfg.role_boundaries — see " + "docs/multimodal_assistant_mask.md for the format and the " + "list of strategies on this fallback path.", + key, + ) + return labels + + return _apply_role_boundaries( + labels, + self.role_boundaries, + roles_to_train=set(self.roles_to_train), + train_on_eos=self.train_on_eos, + ) def process_labels(self, input_ids: Tensor) -> Tensor: labels = input_ids.clone() - labels = self._mask_non_assistant(labels) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - - # Ignore the image token index in the loss computation (model specific) - labels[labels == self.image_token_id] = -100 - + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + if self.image_token_id is not None: + labels[labels == self.image_token_id] = -100 return labels +def _apply_role_boundaries( + labels: Tensor, + role_boundaries: list[RoleBoundary], + roles_to_train: set[str], + train_on_eos: str, +) -> Tensor: + """Mask tokens outside trainable role spans to -100. + + Scan is greedy-left with longest-prefix-wins on start_tokens to disambiguate + nested markers (e.g. ``<|im_start|>assistant`` vs ``<|im_start|>``). + ``train_on_eos`` accepts ``"turn"`` (end marker in loss on trainable turns + only), ``"all"`` (always), ``"none"`` (never — overrides ``include_end``), + ``"last"`` (only on the last trainable turn in the sequence). + """ + mask = zeros_like(labels) + # For "last": remember each trainable turn's end-marker span so we can + # unmask only the final one after the scan finishes. + last_trainable_end_span: list[Optional[tuple[int, int]]] = [None] * labels.shape[0] + + # Work on a Python list per row — avoids O(n*boundaries) Tensor→list + # conversions in the hot prefix-match loop. + def _match_prefix(label: list[int], start_pos: int, tok_seq: list[int]) -> bool: + if not tok_seq or start_pos + len(tok_seq) > len(label): + return False + return label[start_pos : start_pos + len(tok_seq)] == tok_seq + + def _find_end( + label: list[int], start_pos: int, end_tok: list[int] + ) -> tuple[int, bool]: + # Empty end_tok means run to end-of-sequence. + if not end_tok: + return len(label), False + k = start_pos + while k < len(label): + if _match_prefix(label, k, end_tok): + return k + len(end_tok), True + k += 1 + return k, False + + for i in range(labels.shape[0]): + label = labels[i].tolist() + j = 0 + n = len(label) + while j < n: + best_match: Optional[RoleBoundary] = None + for b in role_boundaries: + if _match_prefix(label, j, b.start_tokens): + if best_match is None or len(b.start_tokens) > len( + best_match.start_tokens + ): + best_match = b + if best_match is None: + j += 1 + continue + + start_of_content = j + len(best_match.start_tokens) + end_after, found_end = _find_end( + label, start_of_content, best_match.end_tokens + ) + + role_in_loss = best_match.role in roles_to_train + + if role_in_loss: + if best_match.include_start: + mask[i][j:start_of_content] = 1 + content_end = ( + end_after - len(best_match.end_tokens) if found_end else end_after + ) + mask[i][start_of_content:content_end] = 1 + # train_on_eos="none"/"last" override include_end during main + # loop; "last" is applied after the scan finishes. + if ( + found_end + and best_match.include_end + and train_on_eos not in ("none", "last") + ): + mask[i][content_end:end_after] = 1 + if found_end and best_match.include_end and train_on_eos == "last": + last_trainable_end_span[i] = (content_end, end_after) + else: + # Non-trainable role on train_on_eos="all": gate on include_end + # so Pixtral / Mistral V7 Tekken shared [/INST] doesn't leak. + if found_end and best_match.include_end and train_on_eos == "all": + content_end = end_after - len(best_match.end_tokens) + mask[i][content_end:end_after] = 1 + + # When include_end=False, do not consume the end marker: back up so + # the next iteration can re-match it as the next boundary's start + # marker (Pixtral / Mistral V7 Tekken share [/INST] between + # user-end and assistant-start). Requires end_tokens non-empty and + # actually found. + if found_end and not best_match.include_end and best_match.end_tokens: + j = end_after - len(best_match.end_tokens) + else: + j = end_after + + if train_on_eos == "last" and (span := last_trainable_end_span[i]) is not None: + s, e = span + mask[i][s:e] = 1 + + labels[i][mask[i] == 0] = -100 + + return labels + + +def _encode_markers(tokenizer, marker_strs: list[str]) -> list[list[int]]: + """Encode markers via ``encode(..., add_special_tokens=False)``; drops empty results.""" + result = [] + for s in marker_strs: + toks = tokenizer.encode(s, add_special_tokens=False) + if toks: + result.append(toks) + return result + + +def _resolve_role_boundary_override(specs: list[dict], tokenizer) -> list[RoleBoundary]: + """Resolve user ``cfg.role_boundaries`` specs into RoleBoundary objects. + + The sentinel ``end == "eos_token"`` resolves to ``eos_token_id`` (used by + Pixtral/Mistral v7 templates). ``end`` null/omitted runs to end-of-sequence. + """ + out: list[RoleBoundary] = [] + for i, spec in enumerate(specs): + if hasattr(spec, "model_dump"): + d = spec.model_dump() + else: + d = dict(spec) + + role = d.get("role") + start_str = d.get("start") + if not role or start_str is None: + raise ValueError( + f"cfg.role_boundaries[{i}] must have both 'role' and 'start' " + f"(got {d!r})." + ) + start_ids = tokenizer.encode(start_str, add_special_tokens=False) + if not start_ids: + raise ValueError( + f"cfg.role_boundaries[{i}]: start marker {start_str!r} " + f"tokenizes to an empty sequence; cannot match." + ) + + end_spec = d.get("end") + if end_spec is None: + end_ids: list[int] = [] + elif end_spec == "eos_token": + eos = getattr(tokenizer, "eos_token_id", None) + if eos is None: + raise ValueError( + f"cfg.role_boundaries[{i}] requested end='eos_token' but " + "the tokenizer has no eos_token_id." + ) + end_ids = [eos] + else: + end_ids = tokenizer.encode(end_spec, add_special_tokens=False) + if not end_ids: + raise ValueError( + f"cfg.role_boundaries[{i}]: end marker {end_spec!r} " + f"tokenizes to an empty sequence; cannot match. Use " + f"end=null to run to end-of-sequence or end='eos_token' " + f"to terminate at the tokenizer's EOS." + ) + + out.append( + RoleBoundary( + role=role, + start_tokens=start_ids, + end_tokens=end_ids, + include_start=bool(d.get("include_start", False)), + include_end=bool(d.get("include_end", True)), + ) + ) + return out + + class Qwen2VLProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Qwen2-VL""" + """Processing Strategy class for Qwen2-VL (ChatML ``<|im_start|>{role}\\n ... <|im_end|>``).""" def __init__( self, @@ -250,16 +515,44 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) self.image_token = "<|image_pad|>" # nosec self.image_token_id = processor.tokenizer.convert_tokens_to_ids( self.image_token ) + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, ["<|im_end|>"]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + for role in ("system", "user", "assistant"): + start = _encode_markers(tok, [f"<|im_start|>{role}\n"]) + if start: + boundaries.append( + RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids) + ) + return boundaries -class Qwen3_5ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Qwen3.5 (early-fusion VLM)""" + +class Qwen3_5ProcessingStrategy(Qwen2VLProcessingStrategy): + """Processing Strategy class for Qwen3.5 (Qwen2-VL boundaries + ``<|video_pad|>`` mask).""" def __init__( self, @@ -267,11 +560,20 @@ class Qwen3_5ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) - self.image_token = "<|image_pad|>" # nosec - self.image_token_id = processor.tokenizer.convert_tokens_to_ids( - self.image_token + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, ) self.video_token = "<|video_pad|>" # nosec self.video_token_id = processor.tokenizer.convert_tokens_to_ids( @@ -280,12 +582,44 @@ class Qwen3_5ProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = super().process_labels(input_ids) - labels[labels == self.video_token_id] = -100 + if self.video_token_id is not None: + labels[labels == self.video_token_id] = -100 return labels -class Gemma3ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Gemma3""" +class _GemmaTurnStrategy(ProcessingStrategy): + """Gemma3/3n ``{role} ... `` (Gemma 4 uses different markers).""" + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, [""]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + # Template uses 'model'; external role knob stays 'assistant'. Gemma 3 + # and Gemma 3n jinja templates fold the system message into the first + # user's content prefix and never emit 'system', so we + # don't declare a system boundary here. + role_marker_pairs = [ + ("assistant", "model"), + ("user", "user"), + ] + for external_role, template_role in role_marker_pairs: + start = _encode_markers(tok, [f"{template_role}\n"]) + if start: + boundaries.append( + RoleBoundary( + role=external_role, + start_tokens=start[0], + end_tokens=end_ids, + ) + ) + return boundaries + + +class Gemma3ProcessingStrategy(_GemmaTurnStrategy): + """Processing Strategy class for Gemma3.""" def __init__( self, @@ -293,119 +627,247 @@ class Gemma3ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) - self.image_token = processor.tokenizer.special_tokens_map["boi_token"] - self.image_token_id = processor.tokenizer.convert_tokens_to_ids( - self.image_token + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, ) + # Real Gemma3 tokenizers expose boi_token as a direct attribute, not + # via special_tokens_map (which only holds HF's standard slots). + boi = getattr(processor.tokenizer, "boi_token", None) + if boi is not None: + self.image_token = boi + self.image_token_id = processor.tokenizer.convert_tokens_to_ids(boi) def process_labels(self, input_ids): - labels = input_ids.clone() - - # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - labels[labels == self.image_token_id] = -100 - labels[labels == 262144] = -100 # corresponds to - + labels = super().process_labels(input_ids) + # Resolve via tokenizer; fall back to default id + # if not in vocab. Matches Gemma4's pattern. + tok = self.processor.tokenizer + soft_id = tok.convert_tokens_to_ids("") + unk_id = getattr(tok, "unk_token_id", None) + if soft_id is not None and soft_id != unk_id: + labels[labels == soft_id] = -100 + else: + labels[labels == 262144] = -100 return labels -class Gemma3nProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Gemma3n""" +class Gemma3nProcessingStrategy(_GemmaTurnStrategy): + """Gemma3n: same turn boundaries as Gemma3, additionally masks audio/delimiter tokens.""" - def _mask_non_assistant(self, labels: Tensor) -> Tensor: - def _find_token_sequence(label, start_pos, token_sequence): - """Check if token_sequence appears at start_pos in label""" - if start_pos + len(token_sequence) > len(label): - return False - if label[start_pos] != token_sequence[0]: - return False - return ( - label[start_pos : start_pos + len(token_sequence)].tolist() - == token_sequence - ) + def process_labels(self, input_ids): + labels = super().process_labels(input_ids) + tok = self.processor.tokenizer + # Follows huggingface-gemma-recipes fine_tune_gemma3n_on_t4 notebook. + for attr in ( + "image_token_id", + "audio_token_id", + "boi_token_id", + "eoi_token_id", + ): + tok_id = getattr(tok, attr, None) + if tok_id is not None: + labels[labels == tok_id] = -100 + return labels - def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i): - """ - Find the end of assistant response and update mask accordingly - Returns new position to continue from and whether the end seq is found - """ - k = start_pos - while k < len(label): - if not _find_token_sequence(label, k, assistant_end_tok): - mask[i][k] = 1 - k += 1 - continue +class Gemma4ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Gemma 4. - return k + len(assistant_end_tok), True + Boundary markers ``<|turn>model ... `` verified against + google/gemma-4-E2B-it. boi/eoi/boa/eoa ids are resolved via + ``convert_tokens_to_ids`` since only their string forms are on the processor. + """ - return k, False - - mask = zeros_like(labels) - - assistant_start_str = "model" - assistant_end_str = "" - include_assistant_start_tok = False - include_assistant_end_tok = True - - # str to tokens - assistant_start_tok = self.processor.tokenizer.encode( - assistant_start_str, add_special_tokens=False - ) - assistant_end_tok = self.processor.tokenizer.encode( - assistant_end_str, add_special_tokens=False - ) - - for i, label in enumerate(labels): - j = 0 - # while loop through each tok index in labels[i] - while j < len(label): - # Check until match start seq - if not _find_token_sequence(label, j, assistant_start_tok): - j += 1 - continue - - if include_assistant_start_tok: - mask[i][j : j + len(assistant_start_tok)] = 1 - - # Find where the assistant response ends - start_of_content = j + len(assistant_start_tok) - end_pos, found_end_seq = _find_assistant_end( - label, start_of_content, assistant_end_tok, mask, i + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, [""]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + role_marker_pairs = [ + ("assistant", "model"), + ("user", "user"), + ("system", "system"), + ] + for external_role, template_role in role_marker_pairs: + # Include trailing ``\n`` for consistency with Qwen/Gemma3/Llama + # markers; the newline is part of the marker in the real + # google/gemma-4 tokenizer's chat template. + start = _encode_markers(tok, [f"<|turn>{template_role}\n"]) + if start: + boundaries.append( + RoleBoundary( + role=external_role, + start_tokens=start[0], + end_tokens=end_ids, + ) ) - - # Include end token if requested - if include_assistant_end_tok and found_end_seq: - mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1 - - j = end_pos - - labels[i][mask[i] == 0] = -100 - - return labels + return boundaries def process_labels(self, input_ids): - labels = input_ids.clone() - labels = self._mask_non_assistant(labels) + labels = super().process_labels(input_ids) - # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - if hasattr(self.processor.tokenizer, "image_token_id"): - labels[labels == self.processor.tokenizer.image_token_id] = -100 - if hasattr(self.processor.tokenizer, "audio_token_id"): - labels[labels == self.processor.tokenizer.audio_token_id] = -100 - if hasattr(self.processor.tokenizer, "boi_token_id"): - labels[labels == self.processor.tokenizer.boi_token_id] = -100 - if hasattr(self.processor.tokenizer, "eoi_token_id"): - labels[labels == self.processor.tokenizer.eoi_token_id] = -100 + tokenizer = self.processor.tokenizer + unk_id = getattr(tokenizer, "unk_token_id", None) + + if getattr(tokenizer, "image_token_id", None) is not None: + labels[labels == tokenizer.image_token_id] = -100 + if getattr(tokenizer, "audio_token_id", None) is not None: + labels[labels == tokenizer.audio_token_id] = -100 + + # boi/eoi/boa/eoa are only string attrs on the processor; resolve ids here. + for attr in ("boi_token", "eoi_token", "boa_token", "eoa_token"): + token_str = getattr(self.processor, attr, None) + if token_str is None: + continue + token_id = tokenizer.convert_tokens_to_ids(token_str) + if token_id is None or token_id == unk_id: + continue + labels[labels == token_id] = -100 + + # Video id lives on the processor, not the tokenizer. + video_token_id = getattr(self.processor, "video_token_id", None) + if video_token_id is not None and video_token_id != unk_id: + labels[labels == video_token_id] = -100 return labels +class Llama3_2VisionProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Llama-3.2 Vision (``<|start_header_id|>{role}<|end_header_id|>\\n\\n ... <|eot_id|>``).""" + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, ["<|eot_id|>"]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + for role in ("system", "user", "assistant", "ipython", "tool"): + start = _encode_markers( + tok, [f"<|start_header_id|>{role}<|end_header_id|>\n\n"] + ) + if start: + boundaries.append( + RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids) + ) + return boundaries + + +class Llama4ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Llama 4 (``<|header_start|>{role}<|header_end|>\\n\\n ... <|eot|>``).""" + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + end = _encode_markers(tok, ["<|eot|>"]) + if not end: + return [] + end_ids = end[0] + boundaries = [] + for role in ("system", "user", "assistant", "ipython", "tool"): + start = _encode_markers(tok, [f"<|header_start|>{role}<|header_end|>\n\n"]) + if start: + boundaries.append( + RoleBoundary(role=role, start_tokens=start[0], end_tokens=end_ids) + ) + return boundaries + + +class PixtralProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Pixtral (``[INST] ... [/INST]`` user, assistant terminates at ``eos_token``). + + ``[/INST]`` is shared between user-end and assistant-start. We declare user + with ``include_end=False`` so the scanner hands the ``[/INST]`` back to + assistant's start match on the next iteration. + """ + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + eos = getattr(tok, "eos_token_id", None) + if eos is None: + return [] + boundaries = [] + inst_start = _encode_markers(tok, ["[INST]"]) + inst_end = _encode_markers(tok, ["[/INST]"]) + if inst_start and inst_end: + boundaries.append( + RoleBoundary( + role="user", + start_tokens=inst_start[0], + end_tokens=inst_end[0], + include_end=False, + ) + ) + boundaries.append( + RoleBoundary( + role="assistant", + start_tokens=inst_end[0], + end_tokens=[eos], + ) + ) + return boundaries + + +class MistralV7TekkenProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Mistral v7 Tekken (Pixtral-style plus ``[SYSTEM_PROMPT]...[/SYSTEM_PROMPT]``). + + Same ``[/INST]``-shared-marker treatment as :class:`PixtralProcessingStrategy`. + """ + + def _build_role_boundaries(self) -> list[RoleBoundary]: + tok = self.processor.tokenizer + eos = getattr(tok, "eos_token_id", None) + if eos is None: + return [] + boundaries = [] + sys_start = _encode_markers(tok, ["[SYSTEM_PROMPT]"]) + sys_end = _encode_markers(tok, ["[/SYSTEM_PROMPT]"]) + if sys_start and sys_end: + boundaries.append( + RoleBoundary( + role="system", start_tokens=sys_start[0], end_tokens=sys_end[0] + ) + ) + inst_start = _encode_markers(tok, ["[INST]"]) + inst_end = _encode_markers(tok, ["[/INST]"]) + if inst_start and inst_end: + boundaries.append( + RoleBoundary( + role="user", + start_tokens=inst_start[0], + end_tokens=inst_end[0], + include_end=False, + ) + ) + boundaries.append( + RoleBoundary( + role="assistant", + start_tokens=inst_end[0], + end_tokens=[eos], + ) + ) + return boundaries + + class VoxtralProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Voxtral""" + """Processing Strategy class for Voxtral. + + Role boundaries NOT declared — mistral-common instruct tokenizer markers + unverified. Falls back to pad+audio masking with a one-shot warning. + """ def __init__( self, @@ -413,8 +875,21 @@ class VoxtralProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) special_ids = ( processor.tokenizer.tokenizer.instruct_tokenizer.audio_encoder.special_ids ) @@ -424,16 +899,25 @@ class VoxtralProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - labels[labels == self.audio_token] = -100 - labels[labels == self.begin_audio_token] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + if self.audio_token is not None: + labels[labels == self.audio_token] = -100 + if self.begin_audio_token is not None: + labels[labels == self.begin_audio_token] = -100 return labels class SmolVLM2ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for SmolVLM2""" + """Processing Strategy class for SmolVLM2. + + Role boundaries NOT declared — SmolVLM2 chat_template varies per checkpoint + (HuggingFaceTB ships multiple variants), so we opt out rather than mis-mask. + """ def __init__( self, @@ -441,8 +925,21 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) self.image_token = "" # nosec self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ @@ -451,7 +948,11 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy): class Mistral3ProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for Mistral3""" + """Processing Strategy class for Mistral3. + + Role boundaries NOT declared (mistral-common instruct tokenizer unverified); + same fallback as VoxtralProcessingStrategy. + """ def __init__( self, @@ -459,8 +960,21 @@ class Mistral3ProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) special_ids = ( processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids ) @@ -471,17 +985,24 @@ class Mistral3ProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - labels[labels == self.image_token] = -100 - labels[labels == self.image_break_token] = -100 - labels[labels == self.image_end_token] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + for tok_id in (self.image_token, self.image_break_token, self.image_end_token): + if tok_id is not None: + labels[labels == tok_id] = -100 return labels class InternVLProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for InternVL""" + """Processing Strategy class for InternVL. + + Role boundaries NOT declared (InternLM-style template unverified); falls + back to pad + image-id masking with a one-shot warning. + """ def __init__( self, @@ -489,8 +1010,21 @@ class InternVLProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) if not hasattr(processor, "image_ids"): raise ValueError("'image_ids' missing from InternVL Processor.") @@ -499,20 +1033,26 @@ class InternVLProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.processor.tokenizer.pad_token_id] = -100 + pad_id = getattr(self.processor.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 for ids in self.image_token_ids: - labels[labels == ids] = -100 - - # Note: Check if need to mask 'video_token' as it gets converted to - # image patches during media processing + if ids is not None: + labels[labels == ids] = -100 + # Video tokens get converted to image patches during media processing; masking may be redundant. return labels class Glm4vProcessingStrategy(ProcessingStrategy): - """Processing Strategy class for GLM4V and GLM4V-MoE vision models.""" + """Shared strategy for Glm4vProcessor (GLM-4V / GLM-4.1V) and + Glm46VProcessor (GLM-4.6V / GLM-4.7V) — identical media-token markers. + + Role boundaries unverified; use cfg.role_boundaries to enable masking. + """ def __init__( self, @@ -520,8 +1060,21 @@ class Glm4vProcessingStrategy(ProcessingStrategy): chat_template: Optional[str] = None, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - super().__init__(processor, chat_template, image_size, image_resize_algorithm) + super().__init__( + processor, + chat_template, + image_size, + image_resize_algorithm, + train_on_inputs=train_on_inputs, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + role_boundaries_override=role_boundaries_override, + ) self.tokenizer = getattr(processor, "tokenizer", processor) @@ -549,16 +1102,22 @@ class Glm4vProcessingStrategy(ProcessingStrategy): def process_labels(self, input_ids): labels = input_ids.clone() + labels = self._mask_non_assistant(labels) - labels[labels == self.tokenizer.pad_token_id] = -100 + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 - labels[labels == self.image_token_id] = -100 - labels[labels == self.begin_image_token_id] = -100 - labels[labels == self.end_image_token_id] = -100 - - labels[labels == self.video_token_id] = -100 - labels[labels == self.begin_video_token_id] = -100 - labels[labels == self.end_video_token_id] = -100 + for tok_id in ( + self.image_token_id, + self.begin_image_token_id, + self.end_image_token_id, + self.video_token_id, + self.begin_video_token_id, + self.end_video_token_id, + ): + if tok_id is not None: + labels[labels == tok_id] = -100 return labels @@ -569,14 +1128,20 @@ def get_processing_strategy( chat_template_type, image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, + train_on_inputs: bool = False, + roles_to_train: Optional[list[str]] = None, + train_on_eos: Optional[str] = None, + role_boundaries_override: Optional[list[dict]] = None, ): - from axolotl.utils.mistral.mistral3_processor import Mistral3Processor - processing_kwargs = { "processor": processor, "chat_template": chat_template, "image_size": image_size, "image_resize_algorithm": image_resize_algorithm, + "train_on_inputs": train_on_inputs, + "roles_to_train": roles_to_train, + "train_on_eos": train_on_eos, + "role_boundaries_override": role_boundaries_override, } if chat_template_type in [None, "tokenizer_default"]: @@ -585,53 +1150,63 @@ def get_processing_strategy( processing_kwargs["chat_template"] = tokenizer.chat_template if chat_template_type == "qwen2_vl": - return Qwen2VLProcessingStrategy( - **processing_kwargs, - ) - if chat_template_type in ["qwen3_5", "qwen3_5_moe"]: - return Qwen3_5ProcessingStrategy( - **processing_kwargs, - ) + return Qwen2VLProcessingStrategy(**processing_kwargs) + if chat_template_type == "qwen3_5": + return Qwen3_5ProcessingStrategy(**processing_kwargs) if chat_template_type == "gemma3": - return Gemma3ProcessingStrategy( - **processing_kwargs, - ) + return Gemma3ProcessingStrategy(**processing_kwargs) if chat_template_type == "gemma3n": - return Gemma3nProcessingStrategy( - **processing_kwargs, - ) + return Gemma3nProcessingStrategy(**processing_kwargs) + if chat_template_type == "gemma4": + return Gemma4ProcessingStrategy(**processing_kwargs) + if chat_template_type == "llama3_2_vision": + return Llama3_2VisionProcessingStrategy(**processing_kwargs) + if chat_template_type == "llama4": + return Llama4ProcessingStrategy(**processing_kwargs) + if chat_template_type == "pixtral": + return PixtralProcessingStrategy(**processing_kwargs) + if chat_template_type == "mistral_v7_tekken": + return MistralV7TekkenProcessingStrategy(**processing_kwargs) if isinstance(processor, VoxtralProcessor): - return VoxtralProcessingStrategy( - **processing_kwargs, - ) + return VoxtralProcessingStrategy(**processing_kwargs) if isinstance(processor, SmolVLMProcessor): - return SmolVLM2ProcessingStrategy( - **processing_kwargs, + return SmolVLM2ProcessingStrategy(**processing_kwargs) + + # Lazy import: mistral_common is optional. Mirrors the Glm46V pattern below. + try: + from axolotl.utils.mistral.mistral3_processor import Mistral3Processor + + if isinstance(processor, Mistral3Processor): + return Mistral3ProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug( + "Mistral3Processor import failed; Mistral3 strategy will be unavailable: %r", + exc, ) - if isinstance(processor, Mistral3Processor): - return Mistral3ProcessingStrategy( - **processing_kwargs, - ) + # Both Glm4vProcessor and Glm46VProcessor share markers; route to the same + # strategy. Independent try/except so either can be absent. + try: + from transformers.models.glm4v.processing_glm4v import Glm4vProcessor + + if isinstance(processor, Glm4vProcessor): + return Glm4vProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug("Glm4vProcessor import failed: %r", exc) + try: from transformers.models.glm46v.processing_glm46v import Glm46VProcessor if isinstance(processor, Glm46VProcessor): - return Glm4vProcessingStrategy( - **processing_kwargs, - ) - except ImportError: - pass + return Glm4vProcessingStrategy(**processing_kwargs) + except (ImportError, ModuleNotFoundError) as exc: + LOG.debug("Glm46VProcessor import failed: %r", exc) if isinstance(processor, InternVLProcessor): - return InternVLProcessingStrategy( - **processing_kwargs, - ) + return InternVLProcessingStrategy(**processing_kwargs) - # llama3_2_vision, llama4, llava - # mistral_v7_tekken, pixtral, lfm2vl - return ProcessingStrategy( - **processing_kwargs, - ) + # Unregistered templates (llava, lfm2vl, mistral_v3_tekken, ...) use the + # base strategy; it warns once when train_on_inputs=False. + return ProcessingStrategy(**processing_kwargs) diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 6114a63e0..97ed71631 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -166,10 +166,10 @@ class SFTDataset(BaseModel): "description": "Roles to train on. The tokens from these roles will be considered for the loss." }, ) - train_on_eos: Literal["all", "turn", "last"] | None = Field( + train_on_eos: Literal["all", "turn", "last", "none"] | None = Field( default=None, json_schema_extra={ - "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation" + "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens" }, ) roles: dict[str, list[str]] | None = Field( diff --git a/src/axolotl/utils/schemas/multimodal.py b/src/axolotl/utils/schemas/multimodal.py index a3449199f..01ad5e5a3 100644 --- a/src/axolotl/utils/schemas/multimodal.py +++ b/src/axolotl/utils/schemas/multimodal.py @@ -6,6 +6,57 @@ from PIL.Image import Resampling from pydantic import BaseModel, Field, field_validator +class RoleBoundarySpec(BaseModel): + """One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md.""" + + role: str = Field( + json_schema_extra={ + "description": ( + "Role name as it appears in cfg.roles_to_train (e.g. " + "'assistant', 'user', 'system', 'tool', 'ipython')." + ) + }, + ) + start: str = Field( + json_schema_extra={ + "description": ( + "Literal string that marks the start of this role's span in " + "the rendered chat template. Tokenized via " + "``tokenizer.encode(..., add_special_tokens=False)`` at " + "strategy init." + ) + }, + ) + end: str | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Literal string that marks the end of this role's span. " + "Set to ``eos_token`` to terminate at the tokenizer's EOS. " + "Leave unset / null to terminate at end-of-sequence." + ) + }, + ) + include_start: bool = Field( + default=False, + json_schema_extra={ + "description": ( + "Whether the start marker tokens contribute to loss on " + "trainable turns. Default False." + ) + }, + ) + include_end: bool = Field( + default=True, + json_schema_extra={ + "description": ( + "Whether the end marker tokens contribute to loss on " + "trainable turns (honoring cfg.train_on_eos). Default True." + ) + }, + ) + + class MultiModalConfig(BaseModel): """Multi-modal configuration subset""" @@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel): "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details." }, ) + role_boundaries: list[RoleBoundarySpec] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Opt-in override for the MM mask scanner's per-role boundary " + "markers. Non-empty list replaces built-ins wholesale; unset " + "or empty falls back to built-ins. See " + "docs/multimodal_assistant_mask.md." + ) + }, + ) @field_validator("image_resize_algorithm", mode="before") @classmethod diff --git a/tests/test_processing_strategies.py b/tests/test_processing_strategies.py new file mode 100644 index 000000000..2d8f13fe5 --- /dev/null +++ b/tests/test_processing_strategies.py @@ -0,0 +1,1164 @@ +"""Tests for ``axolotl.processing_strategies`` using fake tokenizers (offline/CI-safe).""" + +import logging + +import pytest +import torch +from pydantic import ValidationError + +from axolotl.processing_strategies import ( + Gemma3nProcessingStrategy, + Gemma3ProcessingStrategy, + Gemma4ProcessingStrategy, + Llama3_2VisionProcessingStrategy, + Llama4ProcessingStrategy, + MistralV7TekkenProcessingStrategy, + PixtralProcessingStrategy, + ProcessingStrategy, + Qwen2VLProcessingStrategy, + Qwen3_5ProcessingStrategy, + RoleBoundary, + _apply_role_boundaries, + get_processing_strategy, +) + + +@pytest.fixture +def axolotl_caplog(caplog): + """caplog that also captures records from the ``axolotl`` logger. + + The axolotl logger sets ``propagate=False`` once ``configure_logging()`` is + called (which happens indirectly in many CI test paths), so the default + caplog handler installed on the root logger never sees these records. + Attaching ``caplog.handler`` to ``axolotl.processing_strategies`` directly + makes assertions reliable regardless of whether ``configure_logging`` has + already run on this worker. + """ + logger = logging.getLogger("axolotl.processing_strategies") + logger.addHandler(caplog.handler) + previous_level = logger.level + logger.setLevel(logging.DEBUG) + try: + yield caplog + finally: + logger.removeHandler(caplog.handler) + logger.setLevel(previous_level) + + +# --------------------------------------------------------------------------- # +# Generic fake tokenizer/processor scaffold +# --------------------------------------------------------------------------- # + + +class _Tokenizer: + """Minimal tokenizer stub; ``vocab`` maps marker strings to their id lists.""" + + def __init__( + self, + vocab: dict[str, list[int]], + pad_id: int = 0, + unk_id: int = 3, + eos_id: int | None = None, + ): + self.vocab = vocab + self._reverse = {} + for tok, ids in vocab.items(): + if len(ids) == 1: + self._reverse[ids[0]] = tok + self.pad_token_id = pad_id + self.unk_token_id = unk_id + if eos_id is not None: + self.eos_token_id = eos_id + + def encode(self, text, add_special_tokens=False): + # Unknown markers return [] so _encode_markers drops them silently. + return list(self.vocab.get(text, [])) + + def convert_tokens_to_ids(self, token): + v = self.vocab.get(token) + if v is None: + return self.unk_token_id + return v[0] if len(v) == 1 else self.unk_token_id + + +class _Processor: + def __init__(self, tokenizer: _Tokenizer): + self.tokenizer = tokenizer + + +# --------------------------------------------------------------------------- # +# Base scanner tests (train_on_inputs / roles_to_train / train_on_eos) +# --------------------------------------------------------------------------- # + + +def _scan(role_boundaries, seq, roles_to_train=("assistant",), train_on_eos="turn"): + labels = torch.tensor([seq]) + return _apply_role_boundaries( + labels, role_boundaries, set(roles_to_train), train_on_eos + ).tolist()[0] + + +def test_scanner_assistant_only_basic(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 8, 9, 5] + out = _scan(boundaries, seq) + assert out == [-100, -100, -100, -100, -100, -100, 8, 8, 9, -100] + + +def test_scanner_train_on_eos_none_excludes_end_marker(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 8, 8, 9] + out = _scan(boundaries, seq, train_on_eos="none") + assert out == [-100, -100, 8, 8, -100] + + +def test_scanner_train_on_eos_all_keeps_non_assistant_end_marker(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9] + out = _scan(boundaries, seq, train_on_eos="all") + assert out == [-100, -100, -100, 9, -100, -100, 8, 9] + + +def test_scanner_train_on_eos_all_with_non_trainable_include_end_false(): + """Non-trainable + include_end=False must not leak end marker on 'all'.""" + boundaries = [ + RoleBoundary( + role="user", + start_tokens=[50], + end_tokens=[51], + include_end=False, # shared with assistant-start + ), + RoleBoundary( + role="assistant", + start_tokens=[51], + end_tokens=[99], # eos + ), + ] + seq = [50, 7, 51, 8, 8, 99] # [INST] 7 [/INST] 8 8 EOS + out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="all") + # [/INST] at idx 2 must stay masked — user.include_end=False says so. + assert out == [-100, -100, -100, 8, 8, 99] + + +def test_scanner_roles_to_train_user_and_assistant(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9] + out = _scan(boundaries, seq, roles_to_train=("user", "assistant")) + # include_start defaults to False so role-start markers stay masked. + assert out == [-100, -100, 7, 9, -100, -100, 8, 9] + + +def test_scanner_truncated_assistant(): + """Missing end marker: span runs to end-of-sequence, end marker not emitted.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 8, 8, 8] + out = _scan(boundaries, seq) + assert out == [-100, -100, 8, 8, 8] + + +def test_scanner_longest_prefix_wins(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2, 4], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 4, 8, 9] + out = _scan(boundaries, seq) + assert out == [-100, -100, -100, 8, 9] + + +def test_scanner_no_boundaries_masks_everything(): + # Strategies short-circuit this in _mask_non_assistant; see test_base_strategy_warns_when_no_boundaries. + labels = torch.tensor([[1, 2, 3, 4]]) + out = _apply_role_boundaries(labels, [], {"assistant"}, "turn") + assert out.tolist() == [[-100, -100, -100, -100]] + + +def test_scanner_train_on_eos_last_only_final_trainable_turn(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 5, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq, train_on_eos="last") + # Only the second assistant turn's end marker (index 7) is kept. + assert out == [-100, -100, 5, -100, -100, -100, 6, 9] + + +def test_scanner_train_on_eos_last_no_trainable_turn_is_noop(): + boundaries = [ + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 5, 9, 1, 3, 6, 9] + out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="last") + assert out == [-100] * 8 + + +def test_strategy_rejects_unknown_train_on_eos(): + vocab = {"BOA": [50], "EOT": [60]} + with pytest.raises(ValueError, match="train_on_eos"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + train_on_eos="bogus", + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"} + ], + ) + + +def test_strategy_accepts_all_supported_train_on_eos_values(): + vocab = {"BOA": [50], "EOT": [60]} + for val in ("turn", "all", "none", "last"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + train_on_eos=val, + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"} + ], + ) + + +def test_empty_role_boundaries_override_falls_back_to_builtin(): + """Empty override must fall through to built-ins (opt-in semantics).""" + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + } + strat_empty = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[], + ) + strat_default = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + ) + # Empty override === no override: both strategies keep the built-in boundaries. + assert strat_empty.role_boundaries == strat_default.role_boundaries + assert len(strat_empty.role_boundaries) > 0 # sanity: built-ins are non-empty + + +def test_sft_dataset_schema_accepts_all_supported_train_on_eos_values(): + """SFTDataset.train_on_eos accepts every value the scanner honors.""" + from axolotl.utils.schemas.datasets import SFTDataset + + for val in ("all", "turn", "last", "none"): + ds = SFTDataset(path="dummy", type="chat_template", train_on_eos=val) + assert ds.train_on_eos == val + + with pytest.raises(ValidationError): + SFTDataset(path="dummy", type="chat_template", train_on_eos="bogus") + + +def test_strategy_init_logs_resolved_masking_config_builtin(axolotl_caplog): + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + } + with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"): + Qwen2VLProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + msgs = [r.getMessage() for r in axolotl_caplog.records] + assert any( + "ProcessingStrategy init" in m + and "Qwen2VLProcessingStrategy" in m + and "boundaries_source=built-in" in m + for m in msgs + ) + + +def test_strategy_init_logs_resolved_masking_config_override(axolotl_caplog): + vocab = {"BOA": [50, 51], "EOT": [60]} + with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + ], + ) + msgs = [r.getMessage() for r in axolotl_caplog.records] + # Resolved start/end ids must appear in the log so users can verify what + # was actually matched. + assert any( + "ProcessingStrategy init" in m + and "boundaries_source=override" in m + and "[50, 51]" in m + and "[60]" in m + for m in msgs + ) + + +def test_process_labels_no_warning_when_image_token_id_none(): + """image_token_id=None must not trigger a UserWarning from ``labels == None``.""" + import warnings + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}], + ) + assert strategy.image_token_id is None + with warnings.catch_warnings(): + warnings.simplefilter("error") + strategy.process_labels(torch.tensor([[1, 50, 2, 3, 60]])) + + +def test_roles_to_train_empty_list_masks_everything(): + """An explicit empty list is distinct from None and disables all roles.""" + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + roles_to_train=[], + role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}], + ) + assert strategy.roles_to_train == [] + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 6 + + +# --------------------------------------------------------------------------- # +# Qwen2VL / Qwen3.5 +# --------------------------------------------------------------------------- # + + +def _qwen_tokenizer(): + # ChatML-ish with image_pad=200, video_pad=201. + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_start|>system\n": [101, 105, 103], + "<|im_end|>": [104], + "<|image_pad|>": [200], + "<|video_pad|>": [201], + } + return _Tokenizer(vocab, pad_id=0) + + +def _make_qwen2vl(): + tok = _qwen_tokenizer() + return Qwen2VLProcessingStrategy(_Processor(tok)) + + +def test_qwen2vl_masks_user_keeps_assistant_and_image_pad(): + strategy = _make_qwen2vl() + seq = [ + 101, + 105, + 103, + 77, + 104, + 101, + 106, + 103, + 7, + 104, + 101, + 102, + 103, + 200, + 8, + 104, + ] + labels = strategy.process_labels(torch.tensor([seq])) + out = labels.tolist()[0] + assert out[:10] == [-100] * 10 + assert out[10] == -100 and out[11] == -100 and out[12] == -100 + assert out[13] == -100 # image_pad masked post-scan + assert out[14] == 8 + assert out[15] == 104 + + +def test_qwen3_5_masks_video_pad_too(): + tok = _qwen_tokenizer() + strategy = Qwen3_5ProcessingStrategy(_Processor(tok)) + seq = [101, 102, 103, 201, 8, 104] + labels = strategy.process_labels(torch.tensor([seq])) + assert labels.tolist()[0] == [-100, -100, -100, -100, 8, 104] + + +def test_qwen2vl_train_on_inputs_true_keeps_everything(): + tok = _qwen_tokenizer() + strategy = Qwen2VLProcessingStrategy(_Processor(tok), train_on_inputs=True) + seq = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104] + labels = strategy.process_labels(torch.tensor([seq])) + assert labels.tolist()[0] == seq + + +# --------------------------------------------------------------------------- # +# Gemma3 / Gemma3n +# --------------------------------------------------------------------------- # + + +def _gemma_tokenizer(): + vocab = { + "model\n": [1, 2, 3], + "user\n": [1, 10, 3], + "system\n": [1, 11, 3], + "": [4], + "": [50], # boi_token for Gemma3 + } + tok = _Tokenizer(vocab, pad_id=0) + # boi_token is a direct tokenizer attribute on real Gemma3. + tok.boi_token = "" + return tok + + +def test_gemma3_scanner_plus_soft_image_token(): + strategy = Gemma3ProcessingStrategy(_Processor(_gemma_tokenizer())) + seq = [1, 10, 3, 7, 4, 1, 2, 3, 50, 8, 262144, 4] + labels = strategy.process_labels(torch.tensor([seq])) + # boi(50) and soft-image-token(262144) masked post-scan. + assert labels.tolist()[0] == [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 8, + -100, + 4, + ] + + +def test_gemma3n_masks_image_and_audio_attrs(): + tok = _gemma_tokenizer() + # Gemma3n exposes these as integer attrs on the tokenizer. + tok.image_token_id = 70 + tok.audio_token_id = 71 + tok.boi_token_id = 72 + tok.eoi_token_id = 73 + strategy = Gemma3nProcessingStrategy(_Processor(tok)) + seq = [1, 2, 3, 70, 71, 72, 73, 9, 4] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 4] + + +# --------------------------------------------------------------------------- # +# Gemma 4 +# --------------------------------------------------------------------------- # + + +class _FakeGemma4Tokenizer(_Tokenizer): + """Mirrors google/gemma-4-E2B-it token layout. Gemma4 role-start markers + include the trailing newline so the boundary matches the jinja template.""" + + VOCAB = { + "<|turn>model\n": [105, 4368, 108], + "<|turn>user\n": [105, 7777, 108], + "<|turn>system\n": [105, 8888, 108], + "": [106], + "<|image|>": [258880], + "<|video|>": [258884], + "<|audio|>": [258881], + "<|image>": [255999], + "": [258882], + "<|audio>": [256000], + "": [258883], + } + + def __init__(self): + # Pass a fresh dict so per-instance mutations (should any future + # code path introduce them) cannot leak across tests via the + # shared class-level VOCAB. + super().__init__( + {token: list(ids) for token, ids in self.VOCAB.items()}, + pad_id=0, + unk_id=3, + ) + + +class _FakeGemma4Processor: + def __init__(self): + self.tokenizer = _FakeGemma4Tokenizer() + self.tokenizer.image_token_id = self.tokenizer.vocab["<|image|>"][0] + self.tokenizer.audio_token_id = self.tokenizer.vocab["<|audio|>"][0] + self.image_token = "<|image|>" + self.image_token_id = self.tokenizer.vocab["<|image|>"][0] + self.boi_token = "<|image>" + self.eoi_token = "" + self.video_token = "<|video|>" + self.video_token_id = self.tokenizer.vocab["<|video|>"][0] + self.audio_token = "<|audio|>" + self.audio_token_id = self.tokenizer.vocab["<|audio|>"][0] + self.boa_token = "<|audio>" + self.eoa_token = "" + + +def test_gemma4_masks_everything_outside_assistant_span(): + strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor()) + V = strategy.processor.tokenizer.vocab + user_start = V["<|turn>user\n"] + model_start = V["<|turn>model\n"] + turn_end = V[""][0] + seq = [ + 0, + *user_start, + 4444, + turn_end, + *model_start, + 5555, + turn_end, + 9999, + ] + labels = strategy.process_labels(torch.tensor([seq])) + expected = [-100] * (1 + len(user_start) + 1 + 1 + len(model_start)) + [ + 5555, + turn_end, + -100, + ] + assert labels.tolist()[0] == expected + + +def test_gemma4_masks_media_tokens_inside_assistant_span(): + strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor()) + V = strategy.processor.tokenizer.vocab + model_start = V["<|turn>model\n"] + media = [ + V["<|image|>"][0], + V["<|video|>"][0], + V["<|audio|>"][0], + V["<|image>"][0], + V[""][0], + V["<|audio>"][0], + V[""][0], + ] + turn_end = V[""][0] + seq = [*model_start, *media, 9999, turn_end] + labels = strategy.process_labels(torch.tensor([seq])) + expected = [-100] * (len(model_start) + len(media)) + [9999, turn_end] + assert labels.tolist()[0] == expected + + +def test_gemma4_multiple_assistant_turns(): + strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor()) + V = strategy.processor.tokenizer.vocab + turn_end = V[""][0] + + def user_turn(x): + return [*V["<|turn>user\n"], x, turn_end] + + def model_turn(x): + return [*V["<|turn>model\n"], x, turn_end] + + seq = user_turn(1111) + model_turn(2222) + user_turn(3333) + model_turn(4444) + labels = strategy.process_labels(torch.tensor([seq])) + kept = [t for t in labels.tolist()[0] if t != -100] + assert kept == [2222, turn_end, 4444, turn_end] + + +# --------------------------------------------------------------------------- # +# Llama 3.2 Vision / Llama 4 +# --------------------------------------------------------------------------- # + + +def test_llama3_2_vision_assistant_masking(): + vocab = { + "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5], + "<|start_header_id|>user<|end_header_id|>\n\n": [1, 2, 6, 4, 5], + "<|start_header_id|>system<|end_header_id|>\n\n": [1, 2, 7, 4, 5], + "<|start_header_id|>tool<|end_header_id|>\n\n": [1, 2, 8, 4, 5], + "<|start_header_id|>ipython<|end_header_id|>\n\n": [1, 2, 9, 4, 5], + "<|eot_id|>": [10], + } + strategy = Llama3_2VisionProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + seq = [1, 2, 6, 4, 5, 11, 10, 1, 2, 3, 4, 5, 12, 10] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 12 + [12, 10] + + +def test_llama4_assistant_masking(): + vocab = { + "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23], + "<|header_start|>user<|header_end|>\n\n": [20, 21, 24, 23], + "<|header_start|>system<|header_end|>\n\n": [20, 21, 25, 23], + "<|header_start|>tool<|header_end|>\n\n": [20, 21, 26, 23], + "<|header_start|>ipython<|header_end|>\n\n": [20, 21, 27, 23], + "<|eot|>": [30], + } + strategy = Llama4ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + seq = [20, 21, 24, 23, 100, 30, 20, 21, 22, 23, 200, 30] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 10 + [200, 30] + + +# --------------------------------------------------------------------------- # +# Pixtral / Mistral v7 Tekken (eos-terminated assistant) +# --------------------------------------------------------------------------- # + + +def test_pixtral_assistant_terminates_at_eos(): + # [/INST] is both user-end and assistant-start. Scanner backs up when + # user.include_end=False so the next iteration picks [/INST] up as + # assistant-start (Pixtral-specific handling in _build_role_boundaries). + vocab = { + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = PixtralProcessingStrategy(_Processor(tok)) + seq = [50, 7, 51, 8, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # Full-sequence expectation: user span masked; assistant content + eos kept. + assert out == [-100, -100, -100, 8, 8, 99] + + +def test_mistral_v7_tekken_system_user_assistant(): + vocab = { + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = MistralV7TekkenProcessingStrategy(_Processor(tok)) + seq = [40, 5, 41, 50, 7, 51, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # Full-sequence expectation: system + user spans masked; assistant kept. + assert out == [-100, -100, -100, -100, -100, -100, 8, 99] + + +def test_pixtral_train_on_eos_all_respects_user_include_end_false(): + """Pixtral [/INST] (user-end include_end=False) stays masked on 'all'.""" + vocab = {"[INST]": [50], "[/INST]": [51]} + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = PixtralProcessingStrategy(_Processor(tok), train_on_eos="all") + seq = [50, 7, 51, 8, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # [/INST] at idx 2 must stay masked — user.include_end=False says so. + # Assistant content (8, 8) + EOS (99) are unmasked as normal. + assert out == [-100, -100, -100, 8, 8, 99] + + +def test_mistral_v7_tekken_train_on_eos_all_respects_user_include_end_false(): + """System end (include_end=True) unmasked on 'all'; [/INST] stays masked.""" + vocab = { + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + "[INST]": [50], + "[/INST]": [51], + } + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = MistralV7TekkenProcessingStrategy(_Processor(tok), train_on_eos="all") + seq = [40, 5, 41, 50, 7, 51, 8, 99] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + # system content masked, [/SYSTEM_PROMPT]=41 kept (include_end=True + all); + # user + [/INST]=51 masked (include_end=False); assistant 8 + eos 99 kept. + assert out == [-100, -100, 41, -100, -100, -100, 8, 99] + + +# --------------------------------------------------------------------------- # +# Dispatcher routing +# --------------------------------------------------------------------------- # + + +def _dispatch(processor, chat_template_type): + return get_processing_strategy( + processor=processor, + chat_template=None, + chat_template_type=chat_template_type, + ) + + +def test_dispatch_qwen2_vl(): + s = _dispatch(_Processor(_qwen_tokenizer()), "qwen2_vl") + assert isinstance(s, Qwen2VLProcessingStrategy) + + +def test_dispatch_qwen3_5(): + s = _dispatch(_Processor(_qwen_tokenizer()), "qwen3_5") + assert isinstance(s, Qwen3_5ProcessingStrategy) + + +def test_dispatch_gemma3(): + s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3") + assert isinstance(s, Gemma3ProcessingStrategy) + + +def test_dispatch_gemma3n(): + s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3n") + assert isinstance(s, Gemma3nProcessingStrategy) + + +def test_dispatch_gemma4(): + s = _dispatch(_FakeGemma4Processor(), "gemma4") + assert isinstance(s, Gemma4ProcessingStrategy) + + +def test_dispatch_llama3_2_vision(): + vocab = { + "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5], + "<|eot_id|>": [10], + } + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama3_2_vision") + assert isinstance(s, Llama3_2VisionProcessingStrategy) + + +def test_dispatch_llama4(): + vocab = { + "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23], + "<|eot|>": [30], + } + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama4") + assert isinstance(s, Llama4ProcessingStrategy) + + +def test_dispatch_pixtral(): + vocab = {"[INST]": [50], "[/INST]": [51]} + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "pixtral") + assert isinstance(s, PixtralProcessingStrategy) + + +def test_dispatch_mistral_v7_tekken(): + vocab = { + "[INST]": [50], + "[/INST]": [51], + "[SYSTEM_PROMPT]": [40], + "[/SYSTEM_PROMPT]": [41], + } + s = _dispatch( + _Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "mistral_v7_tekken" + ) + assert isinstance(s, MistralV7TekkenProcessingStrategy) + + +def test_dispatch_unknown_falls_back_to_base(): + vocab = {"dummy": [1]} + s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llava") + assert type(s) is ProcessingStrategy + + +def _glm_vision_processor(cls_path): + """Spec'd MagicMock so isinstance(mock, cls) passes without real HF files.""" + from importlib import import_module + from unittest.mock import MagicMock + + mod_name, cls_name = cls_path.rsplit(".", 1) + cls = getattr(import_module(mod_name), cls_name) + + vocab = { + "<|image|>": [200], + "<|begin_of_image|>": [201], + "<|end_of_image|>": [202], + "<|video|>": [210], + "<|begin_of_video|>": [211], + "<|end_of_video|>": [212], + } + tok = _Tokenizer(vocab, pad_id=0) + proc = MagicMock(spec=cls) + proc.tokenizer = tok + # Drop processor.image_token so base class skips its probe. + del proc.image_token + return proc + + +def test_dispatch_glm4v_via_Glm4vProcessor(): + """Glm4vProcessor (GLM-4V) routes to Glm4vProcessingStrategy.""" + pytest.importorskip("transformers.models.glm4v.processing_glm4v") + from axolotl.processing_strategies import Glm4vProcessingStrategy + + proc = _glm_vision_processor( + "transformers.models.glm4v.processing_glm4v.Glm4vProcessor" + ) + s = _dispatch(proc, None) + assert isinstance(s, Glm4vProcessingStrategy) + + +def test_dispatch_glm4v_via_Glm46VProcessor(): + """Glm46VProcessor (GLM-4.6V) also routes to Glm4vProcessingStrategy.""" + pytest.importorskip("transformers.models.glm46v.processing_glm46v") + from axolotl.processing_strategies import Glm4vProcessingStrategy + + proc = _glm_vision_processor( + "transformers.models.glm46v.processing_glm46v.Glm46VProcessor" + ) + s = _dispatch(proc, None) + assert isinstance(s, Glm4vProcessingStrategy) + + +# --------------------------------------------------------------------------- # +# Config-based role-boundary override +# --------------------------------------------------------------------------- # + + +def test_role_boundaries_override_replaces_built_in(): + """Override swaps the built-in boundaries wholesale, not additively.""" + vocab = { + "<|im_start|>assistant\n": [101, 102, 103], + "<|im_start|>user\n": [101, 106, 103], + "<|im_end|>": [104], + ">>>A": [200, 201], + ">>>U": [200, 202], + "<<<": [210], + "<|image_pad|>": [250], + } + strategy = Qwen2VLProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": ">>>A", "end": "<<<"}, + {"role": "user", "start": ">>>U", "end": "<<<"}, + ], + ) + seq = [ + 101, + 106, + 103, + 7, + 104, + 200, + 201, + 9, + 9, + 210, + ] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 9, 210] + + +def test_role_boundaries_override_enables_unverified_strategy(): + """Override lets users opt in to role masking on strategies that default opt out.""" + vocab = { + "BOA": [50, 51], + "EOT": [60], + } + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + ], + ) + seq = [1, 2, 3, 50, 51, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, -100, 7, 8, 60, -100] + + +def test_role_boundaries_override_eos_token_sentinel(): + vocab = {"BOA": [50]} + tok = _Tokenizer(vocab, pad_id=0, eos_id=99) + strategy = ProcessingStrategy( + _Processor(tok), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "eos_token"}, + ], + ) + seq = [1, 50, 7, 7, 99, 2] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, 7, 7, 99, -100] + + +def test_role_boundaries_override_end_null_runs_to_sequence_end(): + vocab = {"BOA": [50]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": None}, + ], + ) + seq = [1, 2, 50, 7, 8, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, 7, 8, 9] + + +def test_role_boundaries_override_rejects_bad_spec(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="must have both 'role' and 'start'"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[{"role": "assistant"}], + ) + + +def test_role_boundaries_override_rejects_unencodable_start(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="tokenizes to an empty sequence"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "MISSING", "end": None} + ], + ) + + +def test_role_boundaries_override_rejects_unencodable_end(): + vocab = {"BOA": [50]} + with pytest.raises(ValueError, match="tokenizes to an empty sequence"): + ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "assistant", "start": "BOA", "end": "MISSING"} + ], + ) + + +def test_role_boundaries_override_accepts_pydantic_models(): + # cfg.role_boundaries arrives as RoleBoundarySpec after pydantic parsing. + from axolotl.utils.schemas.multimodal import RoleBoundarySpec + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + RoleBoundarySpec(role="assistant", start="BOA", end="EOT") + ], + ) + assert len(strategy.role_boundaries) == 1 + assert strategy.role_boundaries[0].role == "assistant" + assert strategy.role_boundaries[0].start_tokens == [50] + assert strategy.role_boundaries[0].end_tokens == [60] + + +def test_base_strategy_warns_when_no_boundaries(axolotl_caplog): + """No boundaries + train_on_inputs=False: one-shot warning, labels unchanged.""" + import axolotl.processing_strategies as mod + + mod._ROLE_MASK_WARNED.discard("ProcessingStrategy") + + vocab = {"dummy": [1]} + s = ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0))) + + with axolotl_caplog.at_level( + logging.WARNING, logger="axolotl.processing_strategies" + ): + labels = s.process_labels(torch.tensor([[1, 2, 3]])) + assert labels.tolist() == [[1, 2, 3]] + assert any("role boundaries" in rec.message for rec in axolotl_caplog.records) + + +# --------------------------------------------------------------------------- # +# Additional edge-case coverage +# --------------------------------------------------------------------------- # + + +def test_scanner_batch_size_greater_than_one(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + labels = torch.tensor( + [ + [1, 3, 7, 9, 1, 2, 8, 9], + [1, 2, 5, 5, 9, 0, 0, 0], + ] + ) + out = _apply_role_boundaries(labels, boundaries, {"assistant"}, "turn").tolist() + assert out[0] == [-100, -100, -100, -100, -100, -100, 8, 9] + assert out[1] == [-100, -100, 5, 5, 9, -100, -100, -100] + + +def test_scanner_adjacent_trainable_turns(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + ] + seq = [1, 2, 5, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq) + assert out == [-100, -100, 5, 9, -100, -100, 6, 9] + + +def test_scanner_train_on_eos_none_multi_turn(): + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 9, 1, 2, 8, 9, 1, 3, 7, 9, 1, 2, 6, 9] + out = _scan(boundaries, seq, train_on_eos="none") + assert out == [ + -100, + -100, + -100, + -100, + -100, + -100, + 8, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 6, + -100, + ] + + +def test_scanner_train_on_eos_all_with_user_turn_no_end_marker(): + """Unclosed non-trainable span with train_on_eos='all': nothing included, no crash.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]), + RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]), + ] + seq = [1, 3, 7, 7, 7] + out = _scan(boundaries, seq, train_on_eos="all") + assert out == [-100, -100, -100, -100, -100] + + +def test_scanner_include_start_true_via_override(): + vocab = {"BOA": [50, 51], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + { + "role": "assistant", + "start": "BOA", + "end": "EOT", + "include_start": True, + }, + ], + ) + seq = [1, 50, 51, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, 50, 51, 7, 8, 60, -100] + + +def test_scanner_include_end_false_via_override(): + """include_end=False drops end marker even with train_on_eos='turn'.""" + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + { + "role": "assistant", + "start": "BOA", + "end": "EOT", + "include_end": False, + }, + ], + ) + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, 7, 8, -100, -100] + + +def test_scanner_empty_start_tokens_is_defensive_noop(): + """Defensive: empty start_tokens matches nothing; everything masked.""" + boundaries = [ + RoleBoundary(role="assistant", start_tokens=[], end_tokens=[9]), + ] + seq = [1, 2, 3, 4, 9] + out = _scan(boundaries, seq) + assert out == [-100] * 5 + + +def test_process_labels_masks_pad_inside_assistant_span(): + """Pad inside a trainable span is still masked post-scan.""" + strategy = _make_qwen2vl() + seq = [101, 102, 103, 8, 0, 8, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, 8, -100, 8, 104] + + +def test_process_labels_all_pad_sequence_does_not_crash(): + strategy = _make_qwen2vl() + seq = [0, 0, 0, 0] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100] + + +def test_qwen2vl_multiple_consecutive_assistant_turns(): + strategy = _make_qwen2vl() + seq = [101, 102, 103, 8, 104, 101, 102, 103, 9, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [ + -100, + -100, + -100, + 8, + 104, + -100, + -100, + -100, + 9, + 104, + ] + + +def test_qwen2vl_batch_of_two_rows(): + strategy = _make_qwen2vl() + row_a = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104] + row_b = [101, 102, 103, 9, 104, 0, 0, 0, 0, 0] + out = strategy.process_labels(torch.tensor([row_a, row_b])).tolist() + assert out[0] == [-100, -100, -100, -100, -100, -100, -100, -100, 8, 104] + assert out[1] == [-100, -100, -100, 9, 104, -100, -100, -100, -100, -100] + + +def test_qwen3_5_train_on_inputs_true_still_masks_video_pad(): + """train_on_inputs=True skips role masking but media tokens are still masked.""" + tok = _qwen_tokenizer() + strategy = Qwen3_5ProcessingStrategy(_Processor(tok), train_on_inputs=True) + seq = [101, 106, 103, 201, 7, 104, 101, 102, 103, 201, 8, 104] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + expected = list(seq) + expected[3] = -100 + expected[9] = -100 + assert out == expected + + +def test_role_boundaries_override_role_not_in_roles_to_train(): + """Override covering only a non-trainable role masks everything.""" + vocab = {"BOU": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + {"role": "user", "start": "BOU", "end": "EOT"}, + ], + ) + seq = [1, 50, 7, 8, 60, 9] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100] * 6 + + +def test_role_boundaries_override_include_start_flag_round_trips(): + from axolotl.utils.schemas.multimodal import RoleBoundarySpec + + vocab = {"BOA": [50], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=[ + RoleBoundarySpec( + role="assistant", start="BOA", end="EOT", include_start=True + ), + ], + ) + assert len(strategy.role_boundaries) == 1 + assert strategy.role_boundaries[0].include_start is True + assert strategy.role_boundaries[0].include_end is True + + +def test_multimodal_config_parses_dict_role_boundaries_to_specs(): + from axolotl.utils.schemas.multimodal import ( + MultiModalConfig, + RoleBoundarySpec, + ) + + cfg = MultiModalConfig( + role_boundaries=[ + {"role": "assistant", "start": "BOA", "end": "EOT"}, + {"role": "user", "start": "BOU", "end": "EOT"}, + ] + ) + assert cfg.role_boundaries is not None + assert len(cfg.role_boundaries) == 2 + assert all(isinstance(rb, RoleBoundarySpec) for rb in cfg.role_boundaries) + + vocab = {"BOA": [50], "BOU": [51], "EOT": [60]} + strategy = ProcessingStrategy( + _Processor(_Tokenizer(vocab, pad_id=0)), + role_boundaries_override=cfg.role_boundaries, + ) + seq = [51, 7, 60, 50, 8, 60] + out = strategy.process_labels(torch.tensor([seq])).tolist()[0] + assert out == [-100, -100, -100, -100, 8, 60]