axolotl/tests/test_processing_strategies.py
thad0ctor 5352d41d32 feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries (#3625)
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.
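The scanner semantics in the first bullet can be illustrated with a self-contained toy (a sketch under assumptions: `mask_roles`, its dict-based boundary format, and `IGNORE` are invented for this illustration and are not the axolotl API; the real scanner also supports `include_start`/`include_end`, all four `train_on_eos` modes, and the marker rewind described below):

```python
IGNORE = -100  # Hugging Face convention: label positions ignored by the loss


def mask_roles(seq, boundaries, roles_to_train=("assistant",)):
    """Toy scanner: keep tokens only inside spans whose role is trainable."""
    labels = [IGNORE] * len(seq)
    i = 0
    while i < len(seq):
        hit = next(
            (b for b in boundaries if seq[i : i + len(b["start"])] == b["start"]),
            None,
        )
        if hit is None:
            i += 1
            continue
        i += len(hit["start"])  # start markers themselves stay masked
        trainable = hit["role"] in roles_to_train
        while i < len(seq) and seq[i] != hit["end"]:
            if trainable:
                labels[i] = seq[i]
            i += 1
        if i < len(seq) and trainable:
            labels[i] = seq[i]  # "turn" semantics: keep the turn's end marker
        i += 1
    return labels


boundaries = [
    {"role": "user", "start": [1, 3], "end": 9},
    {"role": "assistant", "start": [1, 2], "end": 9},
]
# user turn [1,3] 7 [9] is masked; assistant turn [1,2] 8 8 [9] trains.
assert mask_roles([1, 3, 7, 9, 1, 2, 8, 8, 9], boundaries) == [-100] * 6 + [8, 8, 9]
```

The declarative scanner generalizes this loop: strategies only declare the per-role marker dicts, and the shared code handles longest-prefix matching and the eos policies exercised in the tests below.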

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.
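The agreed semantics collapse to a single truthiness check at construction time. A sketch with invented names (`resolve_boundaries`, `BUILTIN_BOUNDARIES`), not the actual ctor code:

```python
BUILTIN_BOUNDARIES = [{"role": "assistant", "start": "BOA", "end": "EOT"}]


def resolve_boundaries(override):
    # Non-empty list replaces built-ins wholesale; None or [] falls back.
    # Honoring [] literally would mask every label and give zero gradient,
    # which is almost always a typo rather than intent.
    return override if override else BUILTIN_BOUNDARIES


assert resolve_boundaries(None) is BUILTIN_BOUNDARIES
assert resolve_boundaries([]) is BUILTIN_BOUNDARIES
custom = [{"role": "user", "start": "U:", "end": "\n"}]
assert resolve_boundaries(custom) is custom
```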

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
2026-05-05 11:25:39 -04:00


"""Tests for ``axolotl.processing_strategies`` using fake tokenizers (offline/CI-safe)."""
import logging
import pytest
import torch
from pydantic import ValidationError
from axolotl.processing_strategies import (
Gemma3nProcessingStrategy,
Gemma3ProcessingStrategy,
Gemma4ProcessingStrategy,
Llama3_2VisionProcessingStrategy,
Llama4ProcessingStrategy,
MistralV7TekkenProcessingStrategy,
PixtralProcessingStrategy,
ProcessingStrategy,
Qwen2VLProcessingStrategy,
Qwen3_5ProcessingStrategy,
RoleBoundary,
_apply_role_boundaries,
get_processing_strategy,
)


@pytest.fixture
def axolotl_caplog(caplog):
    """caplog that also captures records from the ``axolotl`` logger.

    The axolotl logger sets ``propagate=False`` once ``configure_logging()`` is
    called (which happens indirectly in many CI test paths), so the default
    caplog handler installed on the root logger never sees these records.
    Attaching ``caplog.handler`` to ``axolotl.processing_strategies`` directly
    makes assertions reliable regardless of whether ``configure_logging`` has
    already run on this worker.
    """
    logger = logging.getLogger("axolotl.processing_strategies")
    logger.addHandler(caplog.handler)
    previous_level = logger.level
    logger.setLevel(logging.DEBUG)
    try:
        yield caplog
    finally:
        logger.removeHandler(caplog.handler)
        logger.setLevel(previous_level)


# --------------------------------------------------------------------------- #
# Generic fake tokenizer/processor scaffold
# --------------------------------------------------------------------------- #
class _Tokenizer:
    """Minimal tokenizer stub; ``vocab`` maps marker strings to their id lists."""

    def __init__(
        self,
        vocab: dict[str, list[int]],
        pad_id: int = 0,
        unk_id: int = 3,
        eos_id: int | None = None,
    ):
        self.vocab = vocab
        self._reverse = {}
        for tok, ids in vocab.items():
            if len(ids) == 1:
                self._reverse[ids[0]] = tok
        self.pad_token_id = pad_id
        self.unk_token_id = unk_id
        if eos_id is not None:
            self.eos_token_id = eos_id

    def encode(self, text, add_special_tokens=False):
        # Unknown markers return [] so _encode_markers drops them silently.
        return list(self.vocab.get(text, []))

    def convert_tokens_to_ids(self, token):
        v = self.vocab.get(token)
        if v is None:
            return self.unk_token_id
        return v[0] if len(v) == 1 else self.unk_token_id


class _Processor:
    def __init__(self, tokenizer: _Tokenizer):
        self.tokenizer = tokenizer


# --------------------------------------------------------------------------- #
# Base scanner tests (train_on_inputs / roles_to_train / train_on_eos)
# --------------------------------------------------------------------------- #
def _scan(role_boundaries, seq, roles_to_train=("assistant",), train_on_eos="turn"):
    labels = torch.tensor([seq])
    return _apply_role_boundaries(
        labels, role_boundaries, set(roles_to_train), train_on_eos
    ).tolist()[0]


def test_scanner_assistant_only_basic():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 7, 9, 1, 2, 8, 8, 9, 5]
    out = _scan(boundaries, seq)
    assert out == [-100, -100, -100, -100, -100, -100, 8, 8, 9, -100]


def test_scanner_train_on_eos_none_excludes_end_marker():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
    ]
    seq = [1, 2, 8, 8, 9]
    out = _scan(boundaries, seq, train_on_eos="none")
    assert out == [-100, -100, 8, 8, -100]


def test_scanner_train_on_eos_all_keeps_non_assistant_end_marker():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 7, 9, 1, 2, 8, 9]
    out = _scan(boundaries, seq, train_on_eos="all")
    assert out == [-100, -100, -100, 9, -100, -100, 8, 9]


def test_scanner_train_on_eos_all_with_non_trainable_include_end_false():
    """Non-trainable + include_end=False must not leak end marker on 'all'."""
    boundaries = [
        RoleBoundary(
            role="user",
            start_tokens=[50],
            end_tokens=[51],
            include_end=False,  # shared with assistant-start
        ),
        RoleBoundary(
            role="assistant",
            start_tokens=[51],
            end_tokens=[99],  # eos
        ),
    ]
    seq = [50, 7, 51, 8, 8, 99]  # [INST] 7 [/INST] 8 8 EOS
    out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="all")
    # [/INST] at idx 2 must stay masked — user.include_end=False says so.
    assert out == [-100, -100, -100, 8, 8, 99]


def test_scanner_roles_to_train_user_and_assistant():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 7, 9, 1, 2, 8, 9]
    out = _scan(boundaries, seq, roles_to_train=("user", "assistant"))
    # include_start defaults to False so role-start markers stay masked.
    assert out == [-100, -100, 7, 9, -100, -100, 8, 9]


def test_scanner_truncated_assistant():
    """Missing end marker: span runs to end-of-sequence, end marker not emitted."""
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
    ]
    seq = [1, 2, 8, 8, 8]
    out = _scan(boundaries, seq)
    assert out == [-100, -100, 8, 8, 8]


def test_scanner_longest_prefix_wins():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2, 4], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 2], end_tokens=[9]),
    ]
    seq = [1, 2, 4, 8, 9]
    out = _scan(boundaries, seq)
    assert out == [-100, -100, -100, 8, 9]


def test_scanner_no_boundaries_masks_everything():
    # Strategies short-circuit this in _mask_non_assistant; see
    # test_base_strategy_warns_when_no_boundaries.
    labels = torch.tensor([[1, 2, 3, 4]])
    out = _apply_role_boundaries(labels, [], {"assistant"}, "turn")
    assert out.tolist() == [[-100, -100, -100, -100]]


def test_scanner_train_on_eos_last_only_final_trainable_turn():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
    ]
    seq = [1, 2, 5, 9, 1, 2, 6, 9]
    out = _scan(boundaries, seq, train_on_eos="last")
    # Only the second assistant turn's end marker (index 7) is kept.
    assert out == [-100, -100, 5, -100, -100, -100, 6, 9]


def test_scanner_train_on_eos_last_no_trainable_turn_is_noop():
    boundaries = [
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 5, 9, 1, 3, 6, 9]
    out = _scan(boundaries, seq, roles_to_train=("assistant",), train_on_eos="last")
    assert out == [-100] * 8


def test_strategy_rejects_unknown_train_on_eos():
    vocab = {"BOA": [50], "EOT": [60]}
    with pytest.raises(ValueError, match="train_on_eos"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            train_on_eos="bogus",
            role_boundaries_override=[
                {"role": "assistant", "start": "BOA", "end": "EOT"}
            ],
        )


def test_strategy_accepts_all_supported_train_on_eos_values():
    vocab = {"BOA": [50], "EOT": [60]}
    for val in ("turn", "all", "none", "last"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            train_on_eos=val,
            role_boundaries_override=[
                {"role": "assistant", "start": "BOA", "end": "EOT"}
            ],
        )


def test_empty_role_boundaries_override_falls_back_to_builtin():
    """Empty override must fall through to built-ins (opt-in semantics)."""
    vocab = {
        "<|im_start|>assistant\n": [101, 102, 103],
        "<|im_start|>user\n": [101, 106, 103],
        "<|im_end|>": [104],
    }
    strat_empty = Qwen2VLProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[],
    )
    strat_default = Qwen2VLProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
    )
    # Empty override === no override: both strategies keep the built-in boundaries.
    assert strat_empty.role_boundaries == strat_default.role_boundaries
    assert len(strat_empty.role_boundaries) > 0  # sanity: built-ins are non-empty


def test_sft_dataset_schema_accepts_all_supported_train_on_eos_values():
    """SFTDataset.train_on_eos accepts every value the scanner honors."""
    from axolotl.utils.schemas.datasets import SFTDataset

    for val in ("all", "turn", "last", "none"):
        ds = SFTDataset(path="dummy", type="chat_template", train_on_eos=val)
        assert ds.train_on_eos == val
    with pytest.raises(ValidationError):
        SFTDataset(path="dummy", type="chat_template", train_on_eos="bogus")


def test_strategy_init_logs_resolved_masking_config_builtin(axolotl_caplog):
    vocab = {
        "<|im_start|>assistant\n": [101, 102, 103],
        "<|im_start|>user\n": [101, 106, 103],
        "<|im_end|>": [104],
    }
    with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"):
        Qwen2VLProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0)))
    msgs = [r.getMessage() for r in axolotl_caplog.records]
    assert any(
        "ProcessingStrategy init" in m
        and "Qwen2VLProcessingStrategy" in m
        and "boundaries_source=built-in" in m
        for m in msgs
    )


def test_strategy_init_logs_resolved_masking_config_override(axolotl_caplog):
    vocab = {"BOA": [50, 51], "EOT": [60]}
    with axolotl_caplog.at_level(logging.INFO, logger="axolotl.processing_strategies"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            role_boundaries_override=[
                {"role": "assistant", "start": "BOA", "end": "EOT"},
            ],
        )
    msgs = [r.getMessage() for r in axolotl_caplog.records]
    # Resolved start/end ids must appear in the log so users can verify what
    # was actually matched.
    assert any(
        "ProcessingStrategy init" in m
        and "boundaries_source=override" in m
        and "[50, 51]" in m
        and "[60]" in m
        for m in msgs
    )


def test_process_labels_no_warning_when_image_token_id_none():
    """image_token_id=None must not trigger a UserWarning from ``labels == None``."""
    import warnings

    vocab = {"BOA": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}],
    )
    assert strategy.image_token_id is None
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        strategy.process_labels(torch.tensor([[1, 50, 2, 3, 60]]))


def test_roles_to_train_empty_list_masks_everything():
    """An explicit empty list is distinct from None and disables all roles."""
    vocab = {"BOA": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        roles_to_train=[],
        role_boundaries_override=[{"role": "assistant", "start": "BOA", "end": "EOT"}],
    )
    assert strategy.roles_to_train == []
    seq = [1, 50, 7, 8, 60, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100] * 6


# --------------------------------------------------------------------------- #
# Qwen2VL / Qwen3.5
# --------------------------------------------------------------------------- #
def _qwen_tokenizer():
    # ChatML-ish with image_pad=200, video_pad=201.
    vocab = {
        "<|im_start|>assistant\n": [101, 102, 103],
        "<|im_start|>user\n": [101, 106, 103],
        "<|im_start|>system\n": [101, 105, 103],
        "<|im_end|>": [104],
        "<|image_pad|>": [200],
        "<|video_pad|>": [201],
    }
    return _Tokenizer(vocab, pad_id=0)


def _make_qwen2vl():
    tok = _qwen_tokenizer()
    return Qwen2VLProcessingStrategy(_Processor(tok))


def test_qwen2vl_masks_user_keeps_assistant_and_image_pad():
    strategy = _make_qwen2vl()
    # system turn, user turn, then assistant turn containing an image_pad.
    seq = [101, 105, 103, 77, 104, 101, 106, 103, 7, 104, 101, 102, 103, 200, 8, 104]
    labels = strategy.process_labels(torch.tensor([seq]))
    out = labels.tolist()[0]
    assert out[:10] == [-100] * 10
    assert out[10] == -100 and out[11] == -100 and out[12] == -100
    assert out[13] == -100  # image_pad masked post-scan
    assert out[14] == 8
    assert out[15] == 104


def test_qwen3_5_masks_video_pad_too():
    tok = _qwen_tokenizer()
    strategy = Qwen3_5ProcessingStrategy(_Processor(tok))
    seq = [101, 102, 103, 201, 8, 104]
    labels = strategy.process_labels(torch.tensor([seq]))
    assert labels.tolist()[0] == [-100, -100, -100, -100, 8, 104]


def test_qwen2vl_train_on_inputs_true_keeps_everything():
    tok = _qwen_tokenizer()
    strategy = Qwen2VLProcessingStrategy(_Processor(tok), train_on_inputs=True)
    seq = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104]
    labels = strategy.process_labels(torch.tensor([seq]))
    assert labels.tolist()[0] == seq


# --------------------------------------------------------------------------- #
# Gemma3 / Gemma3n
# --------------------------------------------------------------------------- #
def _gemma_tokenizer():
    vocab = {
        "<start_of_turn>model\n": [1, 2, 3],
        "<start_of_turn>user\n": [1, 10, 3],
        "<start_of_turn>system\n": [1, 11, 3],
        "<end_of_turn>": [4],
        "<start_of_image>": [50],  # boi_token for Gemma3
    }
    tok = _Tokenizer(vocab, pad_id=0)
    # boi_token is a direct tokenizer attribute on real Gemma3.
    tok.boi_token = "<start_of_image>"
    return tok


def test_gemma3_scanner_plus_soft_image_token():
    strategy = Gemma3ProcessingStrategy(_Processor(_gemma_tokenizer()))
    seq = [1, 10, 3, 7, 4, 1, 2, 3, 50, 8, 262144, 4]
    labels = strategy.process_labels(torch.tensor([seq]))
    # boi(50) and soft-image-token(262144) masked post-scan.
    assert labels.tolist()[0] == [-100] * 9 + [8, -100, 4]


def test_gemma3n_masks_image_and_audio_attrs():
    tok = _gemma_tokenizer()
    # Gemma3n exposes these as integer attrs on the tokenizer.
    tok.image_token_id = 70
    tok.audio_token_id = 71
    tok.boi_token_id = 72
    tok.eoi_token_id = 73
    strategy = Gemma3nProcessingStrategy(_Processor(tok))
    seq = [1, 2, 3, 70, 71, 72, 73, 9, 4]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 4]


# --------------------------------------------------------------------------- #
# Gemma 4
# --------------------------------------------------------------------------- #
class _FakeGemma4Tokenizer(_Tokenizer):
    """Mirrors google/gemma-4-E2B-it token layout.

    Gemma4 role-start markers include the trailing newline so the boundary
    matches the jinja template.
    """

    VOCAB = {
        "<|turn>model\n": [105, 4368, 108],
        "<|turn>user\n": [105, 7777, 108],
        "<|turn>system\n": [105, 8888, 108],
        "<turn|>": [106],
        "<|image|>": [258880],
        "<|video|>": [258884],
        "<|audio|>": [258881],
        "<|image>": [255999],
        "<image|>": [258882],
        "<|audio>": [256000],
        "<audio|>": [258883],
    }

    def __init__(self):
        # Pass a fresh dict so per-instance mutations (should any future
        # code path introduce them) cannot leak across tests via the
        # shared class-level VOCAB.
        super().__init__(
            {token: list(ids) for token, ids in self.VOCAB.items()},
            pad_id=0,
            unk_id=3,
        )


class _FakeGemma4Processor:
    def __init__(self):
        self.tokenizer = _FakeGemma4Tokenizer()
        self.tokenizer.image_token_id = self.tokenizer.vocab["<|image|>"][0]
        self.tokenizer.audio_token_id = self.tokenizer.vocab["<|audio|>"][0]
        self.image_token = "<|image|>"
        self.image_token_id = self.tokenizer.vocab["<|image|>"][0]
        self.boi_token = "<|image>"
        self.eoi_token = "<image|>"
        self.video_token = "<|video|>"
        self.video_token_id = self.tokenizer.vocab["<|video|>"][0]
        self.audio_token = "<|audio|>"
        self.audio_token_id = self.tokenizer.vocab["<|audio|>"][0]
        self.boa_token = "<|audio>"
        self.eoa_token = "<audio|>"


def test_gemma4_masks_everything_outside_assistant_span():
    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
    V = strategy.processor.tokenizer.vocab
    user_start = V["<|turn>user\n"]
    model_start = V["<|turn>model\n"]
    turn_end = V["<turn|>"][0]
    seq = [0, *user_start, 4444, turn_end, *model_start, 5555, turn_end, 9999]
    labels = strategy.process_labels(torch.tensor([seq]))
    expected = [-100] * (1 + len(user_start) + 1 + 1 + len(model_start)) + [
        5555,
        turn_end,
        -100,
    ]
    assert labels.tolist()[0] == expected


def test_gemma4_masks_media_tokens_inside_assistant_span():
    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
    V = strategy.processor.tokenizer.vocab
    model_start = V["<|turn>model\n"]
    media = [
        V["<|image|>"][0],
        V["<|video|>"][0],
        V["<|audio|>"][0],
        V["<|image>"][0],
        V["<image|>"][0],
        V["<|audio>"][0],
        V["<audio|>"][0],
    ]
    turn_end = V["<turn|>"][0]
    seq = [*model_start, *media, 9999, turn_end]
    labels = strategy.process_labels(torch.tensor([seq]))
    expected = [-100] * (len(model_start) + len(media)) + [9999, turn_end]
    assert labels.tolist()[0] == expected


def test_gemma4_multiple_assistant_turns():
    strategy = Gemma4ProcessingStrategy(_FakeGemma4Processor())
    V = strategy.processor.tokenizer.vocab
    turn_end = V["<turn|>"][0]

    def user_turn(x):
        return [*V["<|turn>user\n"], x, turn_end]

    def model_turn(x):
        return [*V["<|turn>model\n"], x, turn_end]

    seq = user_turn(1111) + model_turn(2222) + user_turn(3333) + model_turn(4444)
    labels = strategy.process_labels(torch.tensor([seq]))
    kept = [t for t in labels.tolist()[0] if t != -100]
    assert kept == [2222, turn_end, 4444, turn_end]


# --------------------------------------------------------------------------- #
# Llama 3.2 Vision / Llama 4
# --------------------------------------------------------------------------- #
def test_llama3_2_vision_assistant_masking():
    vocab = {
        "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5],
        "<|start_header_id|>user<|end_header_id|>\n\n": [1, 2, 6, 4, 5],
        "<|start_header_id|>system<|end_header_id|>\n\n": [1, 2, 7, 4, 5],
        "<|start_header_id|>tool<|end_header_id|>\n\n": [1, 2, 8, 4, 5],
        "<|start_header_id|>ipython<|end_header_id|>\n\n": [1, 2, 9, 4, 5],
        "<|eot_id|>": [10],
    }
    strategy = Llama3_2VisionProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0)))
    seq = [1, 2, 6, 4, 5, 11, 10, 1, 2, 3, 4, 5, 12, 10]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100] * 12 + [12, 10]


def test_llama4_assistant_masking():
    vocab = {
        "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23],
        "<|header_start|>user<|header_end|>\n\n": [20, 21, 24, 23],
        "<|header_start|>system<|header_end|>\n\n": [20, 21, 25, 23],
        "<|header_start|>tool<|header_end|>\n\n": [20, 21, 26, 23],
        "<|header_start|>ipython<|header_end|>\n\n": [20, 21, 27, 23],
        "<|eot|>": [30],
    }
    strategy = Llama4ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0)))
    seq = [20, 21, 24, 23, 100, 30, 20, 21, 22, 23, 200, 30]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100] * 10 + [200, 30]


# --------------------------------------------------------------------------- #
# Pixtral / Mistral v7 Tekken (eos-terminated assistant)
# --------------------------------------------------------------------------- #
def test_pixtral_assistant_terminates_at_eos():
    # [/INST] is both user-end and assistant-start. Scanner backs up when
    # user.include_end=False so the next iteration picks [/INST] up as
    # assistant-start (Pixtral-specific handling in _build_role_boundaries).
    vocab = {
        "[INST]": [50],
        "[/INST]": [51],
    }
    tok = _Tokenizer(vocab, pad_id=0, eos_id=99)
    strategy = PixtralProcessingStrategy(_Processor(tok))
    seq = [50, 7, 51, 8, 8, 99]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    # Full-sequence expectation: user span masked; assistant content + eos kept.
    assert out == [-100, -100, -100, 8, 8, 99]


def test_mistral_v7_tekken_system_user_assistant():
    vocab = {
        "[SYSTEM_PROMPT]": [40],
        "[/SYSTEM_PROMPT]": [41],
        "[INST]": [50],
        "[/INST]": [51],
    }
    tok = _Tokenizer(vocab, pad_id=0, eos_id=99)
    strategy = MistralV7TekkenProcessingStrategy(_Processor(tok))
    seq = [40, 5, 41, 50, 7, 51, 8, 99]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    # Full-sequence expectation: system + user spans masked; assistant kept.
    assert out == [-100, -100, -100, -100, -100, -100, 8, 99]


def test_pixtral_train_on_eos_all_respects_user_include_end_false():
    """Pixtral [/INST] (user-end include_end=False) stays masked on 'all'."""
    vocab = {"[INST]": [50], "[/INST]": [51]}
    tok = _Tokenizer(vocab, pad_id=0, eos_id=99)
    strategy = PixtralProcessingStrategy(_Processor(tok), train_on_eos="all")
    seq = [50, 7, 51, 8, 8, 99]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    # [/INST] at idx 2 must stay masked — user.include_end=False says so.
    # Assistant content (8, 8) + EOS (99) are unmasked as normal.
    assert out == [-100, -100, -100, 8, 8, 99]


def test_mistral_v7_tekken_train_on_eos_all_respects_user_include_end_false():
    """System end (include_end=True) unmasked on 'all'; [/INST] stays masked."""
    vocab = {
        "[SYSTEM_PROMPT]": [40],
        "[/SYSTEM_PROMPT]": [41],
        "[INST]": [50],
        "[/INST]": [51],
    }
    tok = _Tokenizer(vocab, pad_id=0, eos_id=99)
    strategy = MistralV7TekkenProcessingStrategy(_Processor(tok), train_on_eos="all")
    seq = [40, 5, 41, 50, 7, 51, 8, 99]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    # system content masked, [/SYSTEM_PROMPT]=41 kept (include_end=True + all);
    # user + [/INST]=51 masked (include_end=False); assistant 8 + eos 99 kept.
    assert out == [-100, -100, 41, -100, -100, -100, 8, 99]


# --------------------------------------------------------------------------- #
# Dispatcher routing
# --------------------------------------------------------------------------- #
def _dispatch(processor, chat_template_type):
    return get_processing_strategy(
        processor=processor,
        chat_template=None,
        chat_template_type=chat_template_type,
    )


def test_dispatch_qwen2_vl():
    s = _dispatch(_Processor(_qwen_tokenizer()), "qwen2_vl")
    assert isinstance(s, Qwen2VLProcessingStrategy)


def test_dispatch_qwen3_5():
    s = _dispatch(_Processor(_qwen_tokenizer()), "qwen3_5")
    assert isinstance(s, Qwen3_5ProcessingStrategy)


def test_dispatch_gemma3():
    s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3")
    assert isinstance(s, Gemma3ProcessingStrategy)


def test_dispatch_gemma3n():
    s = _dispatch(_Processor(_gemma_tokenizer()), "gemma3n")
    assert isinstance(s, Gemma3nProcessingStrategy)


def test_dispatch_gemma4():
    s = _dispatch(_FakeGemma4Processor(), "gemma4")
    assert isinstance(s, Gemma4ProcessingStrategy)


def test_dispatch_llama3_2_vision():
    vocab = {
        "<|start_header_id|>assistant<|end_header_id|>\n\n": [1, 2, 3, 4, 5],
        "<|eot_id|>": [10],
    }
    s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama3_2_vision")
    assert isinstance(s, Llama3_2VisionProcessingStrategy)


def test_dispatch_llama4():
    vocab = {
        "<|header_start|>assistant<|header_end|>\n\n": [20, 21, 22, 23],
        "<|eot|>": [30],
    }
    s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llama4")
    assert isinstance(s, Llama4ProcessingStrategy)


def test_dispatch_pixtral():
    vocab = {"[INST]": [50], "[/INST]": [51]}
    s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "pixtral")
    assert isinstance(s, PixtralProcessingStrategy)


def test_dispatch_mistral_v7_tekken():
    vocab = {
        "[INST]": [50],
        "[/INST]": [51],
        "[SYSTEM_PROMPT]": [40],
        "[/SYSTEM_PROMPT]": [41],
    }
    s = _dispatch(
        _Processor(_Tokenizer(vocab, pad_id=0, eos_id=99)), "mistral_v7_tekken"
    )
    assert isinstance(s, MistralV7TekkenProcessingStrategy)


def test_dispatch_unknown_falls_back_to_base():
    vocab = {"dummy": [1]}
    s = _dispatch(_Processor(_Tokenizer(vocab, pad_id=0)), "llava")
    assert type(s) is ProcessingStrategy


def _glm_vision_processor(cls_path):
    """Spec'd MagicMock so isinstance(mock, cls) passes without real HF files."""
    from importlib import import_module
    from unittest.mock import MagicMock

    mod_name, cls_name = cls_path.rsplit(".", 1)
    cls = getattr(import_module(mod_name), cls_name)
    vocab = {
        "<|image|>": [200],
        "<|begin_of_image|>": [201],
        "<|end_of_image|>": [202],
        "<|video|>": [210],
        "<|begin_of_video|>": [211],
        "<|end_of_video|>": [212],
    }
    tok = _Tokenizer(vocab, pad_id=0)
    proc = MagicMock(spec=cls)
    proc.tokenizer = tok
    # Drop processor.image_token so base class skips its probe.
    del proc.image_token
    return proc


def test_dispatch_glm4v_via_Glm4vProcessor():
    """Glm4vProcessor (GLM-4V) routes to Glm4vProcessingStrategy."""
    pytest.importorskip("transformers.models.glm4v.processing_glm4v")
    from axolotl.processing_strategies import Glm4vProcessingStrategy

    proc = _glm_vision_processor(
        "transformers.models.glm4v.processing_glm4v.Glm4vProcessor"
    )
    s = _dispatch(proc, None)
    assert isinstance(s, Glm4vProcessingStrategy)


def test_dispatch_glm4v_via_Glm46VProcessor():
    """Glm46VProcessor (GLM-4.6V) also routes to Glm4vProcessingStrategy."""
    pytest.importorskip("transformers.models.glm46v.processing_glm46v")
    from axolotl.processing_strategies import Glm4vProcessingStrategy

    proc = _glm_vision_processor(
        "transformers.models.glm46v.processing_glm46v.Glm46VProcessor"
    )
    s = _dispatch(proc, None)
    assert isinstance(s, Glm4vProcessingStrategy)


# --------------------------------------------------------------------------- #
# Config-based role-boundary override
# --------------------------------------------------------------------------- #
def test_role_boundaries_override_replaces_built_in():
    """Override swaps the built-in boundaries wholesale, not additively."""
    vocab = {
        "<|im_start|>assistant\n": [101, 102, 103],
        "<|im_start|>user\n": [101, 106, 103],
        "<|im_end|>": [104],
        ">>>A": [200, 201],
        ">>>U": [200, 202],
        "<<<": [210],
        "<|image_pad|>": [250],
    }
    strategy = Qwen2VLProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {"role": "assistant", "start": ">>>A", "end": "<<<"},
            {"role": "user", "start": ">>>U", "end": "<<<"},
        ],
    )
    seq = [101, 106, 103, 7, 104, 200, 201, 9, 9, 210]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, -100, -100, -100, -100, 9, 9, 210]


def test_role_boundaries_override_enables_unverified_strategy():
    """Override lets users opt in to role masking on strategies that default opt out."""
    vocab = {
        "BOA": [50, 51],
        "EOT": [60],
    }
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {"role": "assistant", "start": "BOA", "end": "EOT"},
        ],
    )
    seq = [1, 2, 3, 50, 51, 7, 8, 60, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, -100, -100, 7, 8, 60, -100]


def test_role_boundaries_override_eos_token_sentinel():
    """The 'eos_token' sentinel as an end marker resolves to the tokenizer's EOS id."""
    vocab = {"BOA": [50]}
    tok = _Tokenizer(vocab, pad_id=0, eos_id=99)
    strategy = ProcessingStrategy(
        _Processor(tok),
        role_boundaries_override=[
            {"role": "assistant", "start": "BOA", "end": "eos_token"},
        ],
    )
    seq = [1, 50, 7, 7, 99, 2]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, 7, 7, 99, -100]


def test_role_boundaries_override_end_null_runs_to_sequence_end():
    vocab = {"BOA": [50]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {"role": "assistant", "start": "BOA", "end": None},
        ],
    )
    seq = [1, 2, 50, 7, 8, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, 7, 8, 9]


def test_role_boundaries_override_rejects_bad_spec():
    vocab = {"BOA": [50]}
    with pytest.raises(ValueError, match="must have both 'role' and 'start'"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            role_boundaries_override=[{"role": "assistant"}],
        )


def test_role_boundaries_override_rejects_unencodable_start():
    vocab = {"BOA": [50]}
    with pytest.raises(ValueError, match="tokenizes to an empty sequence"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            role_boundaries_override=[
                {"role": "assistant", "start": "MISSING", "end": None}
            ],
        )


def test_role_boundaries_override_rejects_unencodable_end():
    vocab = {"BOA": [50]}
    with pytest.raises(ValueError, match="tokenizes to an empty sequence"):
        ProcessingStrategy(
            _Processor(_Tokenizer(vocab, pad_id=0)),
            role_boundaries_override=[
                {"role": "assistant", "start": "BOA", "end": "MISSING"}
            ],
        )


def test_role_boundaries_override_accepts_pydantic_models():
    # cfg.role_boundaries arrives as RoleBoundarySpec after pydantic parsing.
    from axolotl.utils.schemas.multimodal import RoleBoundarySpec

    vocab = {"BOA": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            RoleBoundarySpec(role="assistant", start="BOA", end="EOT")
        ],
    )
    assert len(strategy.role_boundaries) == 1
    assert strategy.role_boundaries[0].role == "assistant"
    assert strategy.role_boundaries[0].start_tokens == [50]
    assert strategy.role_boundaries[0].end_tokens == [60]


def test_base_strategy_warns_when_no_boundaries(axolotl_caplog):
    """No boundaries + train_on_inputs=False: one-shot warning, labels unchanged."""
    import axolotl.processing_strategies as mod

    # Reset the one-shot guard so this test sees the warning regardless of order.
    mod._ROLE_MASK_WARNED.discard("ProcessingStrategy")
    vocab = {"dummy": [1]}
    s = ProcessingStrategy(_Processor(_Tokenizer(vocab, pad_id=0)))
    with axolotl_caplog.at_level(
        logging.WARNING, logger="axolotl.processing_strategies"
    ):
        labels = s.process_labels(torch.tensor([[1, 2, 3]]))
    assert labels.tolist() == [[1, 2, 3]]
    assert any("role boundaries" in rec.message for rec in axolotl_caplog.records)


# --------------------------------------------------------------------------- #
# Additional edge-case coverage
# --------------------------------------------------------------------------- #
def test_scanner_batch_size_greater_than_one():
    """Each row in a batch is scanned independently."""
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    labels = torch.tensor(
        [
            [1, 3, 7, 9, 1, 2, 8, 9],
            [1, 2, 5, 5, 9, 0, 0, 0],
        ]
    )
    out = _apply_role_boundaries(labels, boundaries, {"assistant"}, "turn").tolist()
    assert out[0] == [-100, -100, -100, -100, -100, -100, 8, 9]
    assert out[1] == [-100, -100, 5, 5, 9, -100, -100, -100]


def test_scanner_adjacent_trainable_turns():
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
    ]
    seq = [1, 2, 5, 9, 1, 2, 6, 9]
    out = _scan(boundaries, seq)
    assert out == [-100, -100, 5, 9, -100, -100, 6, 9]


def test_scanner_train_on_eos_none_multi_turn():
    """train_on_eos='none' masks the end marker even inside trainable turns."""
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 7, 9, 1, 2, 8, 9, 1, 3, 7, 9, 1, 2, 6, 9]
    out = _scan(boundaries, seq, train_on_eos="none")
    assert out == [
        -100,
        -100,
        -100,
        -100,
        -100,
        -100,
        8,
        -100,
        -100,
        -100,
        -100,
        -100,
        -100,
        -100,
        6,
        -100,
    ]


def test_scanner_train_on_eos_all_with_user_turn_no_end_marker():
    """Unclosed non-trainable span with train_on_eos='all': nothing included, no crash."""
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[1, 2], end_tokens=[9]),
        RoleBoundary(role="user", start_tokens=[1, 3], end_tokens=[9]),
    ]
    seq = [1, 3, 7, 7, 7]
    out = _scan(boundaries, seq, train_on_eos="all")
    assert out == [-100, -100, -100, -100, -100]


def test_scanner_include_start_true_via_override():
    """include_start=True keeps the start-marker tokens trainable."""
    vocab = {"BOA": [50, 51], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {
                "role": "assistant",
                "start": "BOA",
                "end": "EOT",
                "include_start": True,
            },
        ],
    )
    seq = [1, 50, 51, 7, 8, 60, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, 50, 51, 7, 8, 60, -100]


def test_scanner_include_end_false_via_override():
    """include_end=False drops end marker even with train_on_eos='turn'."""
    vocab = {"BOA": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {
                "role": "assistant",
                "start": "BOA",
                "end": "EOT",
                "include_end": False,
            },
        ],
    )
    seq = [1, 50, 7, 8, 60, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, 7, 8, -100, -100]


def test_scanner_empty_start_tokens_is_defensive_noop():
    """Defensive: empty start_tokens matches nothing; everything masked."""
    boundaries = [
        RoleBoundary(role="assistant", start_tokens=[], end_tokens=[9]),
    ]
    seq = [1, 2, 3, 4, 9]
    out = _scan(boundaries, seq)
    assert out == [-100] * 5


def test_process_labels_masks_pad_inside_assistant_span():
    """Pad inside a trainable span is still masked post-scan."""
    strategy = _make_qwen2vl()
    seq = [101, 102, 103, 8, 0, 8, 104]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, 8, -100, 8, 104]


def test_process_labels_all_pad_sequence_does_not_crash():
    strategy = _make_qwen2vl()
    seq = [0, 0, 0, 0]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, -100]


def test_qwen2vl_multiple_consecutive_assistant_turns():
    strategy = _make_qwen2vl()
    seq = [101, 102, 103, 8, 104, 101, 102, 103, 9, 104]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, 8, 104, -100, -100, -100, 9, 104]


def test_qwen2vl_batch_of_two_rows():
    strategy = _make_qwen2vl()
    row_a = [101, 106, 103, 7, 104, 101, 102, 103, 8, 104]
    row_b = [101, 102, 103, 9, 104, 0, 0, 0, 0, 0]
    out = strategy.process_labels(torch.tensor([row_a, row_b])).tolist()
    assert out[0] == [-100, -100, -100, -100, -100, -100, -100, -100, 8, 104]
    assert out[1] == [-100, -100, -100, 9, 104, -100, -100, -100, -100, -100]


def test_qwen3_5_train_on_inputs_true_still_masks_video_pad():
    """train_on_inputs=True skips role masking but media tokens are still masked."""
    tok = _qwen_tokenizer()
    strategy = Qwen3_5ProcessingStrategy(_Processor(tok), train_on_inputs=True)
    seq = [101, 106, 103, 201, 7, 104, 101, 102, 103, 201, 8, 104]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    expected = list(seq)
    expected[3] = -100
    expected[9] = -100
    assert out == expected


def test_role_boundaries_override_role_not_in_roles_to_train():
    """Override covering only a non-trainable role masks everything."""
    vocab = {"BOU": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            {"role": "user", "start": "BOU", "end": "EOT"},
        ],
    )
    seq = [1, 50, 7, 8, 60, 9]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100] * 6


def test_role_boundaries_override_include_start_flag_round_trips():
    from axolotl.utils.schemas.multimodal import RoleBoundarySpec

    vocab = {"BOA": [50], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=[
            RoleBoundarySpec(
                role="assistant", start="BOA", end="EOT", include_start=True
            ),
        ],
    )
    assert len(strategy.role_boundaries) == 1
    assert strategy.role_boundaries[0].include_start is True
    assert strategy.role_boundaries[0].include_end is True


def test_multimodal_config_parses_dict_role_boundaries_to_specs():
    from axolotl.utils.schemas.multimodal import (
        MultiModalConfig,
        RoleBoundarySpec,
    )

    cfg = MultiModalConfig(
        role_boundaries=[
            {"role": "assistant", "start": "BOA", "end": "EOT"},
            {"role": "user", "start": "BOU", "end": "EOT"},
        ]
    )
    assert cfg.role_boundaries is not None
    assert len(cfg.role_boundaries) == 2
    assert all(isinstance(rb, RoleBoundarySpec) for rb in cfg.role_boundaries)
    vocab = {"BOA": [50], "BOU": [51], "EOT": [60]}
    strategy = ProcessingStrategy(
        _Processor(_Tokenizer(vocab, pad_id=0)),
        role_boundaries_override=cfg.role_boundaries,
    )
    seq = [51, 7, 60, 50, 8, 60]
    out = strategy.process_labels(torch.tensor([seq])).tolist()[0]
    assert out == [-100, -100, -100, -100, 8, 60]