axolotl/tests/test_mm_chat_collator.py
Wing Lian e4032fc90f Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

Replace checks with capability-based properties derived from attn_implementation

This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
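
As an illustration only, the three-way split listed above could be sketched with an enum plus a table of derived flags. The names `AttnBackend`, `AttnCaps` and `derive_caps`, and every capability value in the table, are assumptions made for this sketch and are not taken from the axolotl code.

```python
from dataclasses import dataclass
from enum import Enum


class AttnBackend(str, Enum):
    # Hypothetical subset of backends, for illustration only.
    EAGER = "eager"
    SDPA = "sdpa"
    FLASH_ATTENTION_2 = "flash_attention_2"


@dataclass(frozen=True)
class AttnCaps:
    supports_packing: bool
    uses_flash_lib: bool


# Placeholder values: the real mapping lives in the project itself.
_CAPS = {
    AttnBackend.EAGER: AttnCaps(supports_packing=False, uses_flash_lib=False),
    AttnBackend.SDPA: AttnCaps(supports_packing=True, uses_flash_lib=False),
    AttnBackend.FLASH_ATTENTION_2: AttnCaps(supports_packing=True, uses_flash_lib=True),
}


def derive_caps(attn_implementation: str) -> AttnCaps:
    """Derive both capability flags from the single backend setting."""
    return _CAPS[AttnBackend(attn_implementation)]
```

The point of the sketch is that the two flags vary independently, so a single boolean such as flash_attention cannot encode both.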

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False).  When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.

The old code placed `return_dict=True` inside processor_kwargs.  In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...) where _merge_kwargs silently ignores any key not present
in MllamaProcessorKwargs (emitting a warning).  The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:

  IndexError: too many indices for tensor of dimension 2

The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped).  No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.
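
The difference between the two call shapes can be sketched with a plain-Python stand-in. The function below is hypothetical and is not the transformers implementation; it only assumes, as described above, that a `return_dict` key nested inside `processor_kwargs` never reaches the outer parameter. Lists stand in for tensors.

```python
def apply_chat_template(conversation, tokenize=True, return_dict=False, **kwargs):
    """Hypothetical stand-in, NOT the transformers implementation."""
    # Assumption: nested processor kwargs are forwarded elsewhere, so a
    # "return_dict" key inside them has no effect on the outer parameter.
    _ignored = kwargs.get("processor_kwargs", {})
    out = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
    return out if return_dict else out["input_ids"]


messages = [{"role": "user", "content": "Hello"}]

# Old call shape: the flag is buried in a nested dict, so the default wins.
broken = apply_chat_template(messages, processor_kwargs={"return_dict": True})

# Fixed call shape: the flag is a top-level keyword argument.
fixed = apply_chat_template(messages, return_dict=True)
```

Here `broken` is the bare input_ids value while `fixed` is the full mapping, which is the distinction the regression test checks.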

Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.

Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

Two related changes for the multimodal chat collator under transformers 5.x:

1. Wrap apply_chat_template result in dict(...) so process_rows returns
   a plain dict rather than a BatchFeature instance. BatchFeature is a
   Mapping but not a dict; downstream code that did
     batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
   would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2

2. Soften the regression test's contract from `dict` to `Mapping` so it
   exercises the actual semantic guarantee (key/value access) rather
   than the implementation detail (dict vs BatchFeature). The test guards
   against the original transformers 5.x breakage, where apply_chat_template's
   own return_dict parameter defaults to False.
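
The dict-versus-Mapping distinction in the two points above can be shown with the standard library alone. This is a minimal sketch that assumes BatchFeature behaves like a `collections.UserDict` subclass; `FakeBatchFeature` is a hypothetical stand-in, not the transformers class.

```python
from collections import UserDict
from collections.abc import Mapping


class FakeBatchFeature(UserDict):
    """Hypothetical stand-in for a UserDict-based container."""


batch = FakeBatchFeature({"input_ids": [[1, 2, 3]]})

is_mapping = isinstance(batch, Mapping)  # key/value access is available
is_dict = isinstance(batch, dict)        # but it is not a dict subclass
plain = dict(batch)                      # shallow copy into a real dict
```

An `isinstance(..., dict)` check passes only for `plain`, while an `isinstance(..., Mapping)` check passes for both objects.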

Includes regression test under tests/test_mm_chat_collator.py.

Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-05 10:15:18 -04:00


"""
Regression tests for MultiModalChatDataCollator shape contracts.

Guard against the transformers 5.x breakage where apply_chat_template's
own `return_dict` parameter (default False) caused it to return the raw
input_ids tensor instead of the full BatchFeature dict, leading to
    IndexError: too many indices for tensor of dimension 2
when downstream code did batch["input_ids"] on the resulting tensor.
"""

from unittest.mock import MagicMock, patch

import pytest
import torch
from transformers import BatchFeature


@pytest.fixture(name="mock_processor")
def fixture_mock_processor():
    """
    A mock processor whose apply_chat_template returns a BatchFeature
    when called with return_dict=True (the correct call convention),
    or a raw input_ids tensor when called without return_dict=True
    (the broken call convention that the bug introduced).
    """
    processor = MagicMock()
    processor.tokenizer = MagicMock()
    processor.tokenizer.pad_token_id = 0
    processor.image_token = "<|image|>"
    processor.tokenizer.convert_tokens_to_ids = MagicMock(return_value=128256)

    batch_size, seq_len = 2, 16
    input_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    batch_feature = BatchFeature(
        data={
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
    )

    def _apply_chat_template(*args, **kwargs):
        if kwargs.get("return_dict", False):
            return batch_feature
        # Simulate transformers 5.x default behaviour: returns out["input_ids"]
        return input_ids

    processor.apply_chat_template = MagicMock(side_effect=_apply_chat_template)
    processor.chat_template = None
    return processor


@pytest.fixture(name="mock_processing_strategy")
def fixture_mock_processing_strategy(mock_processor):
    from axolotl.processing_strategies import ProcessingStrategy

    strategy = ProcessingStrategy(processor=mock_processor)
    return strategy


class TestMultiModalChatDataCollatorShapeContract:
    """
    Verify that MultiModalChatDataCollator.process_rows returns a dict with
    2-D input_ids and labels, not a raw tensor. This is the shape contract
    that process_labels depends on.
    """

    def _make_collator(self, mock_processing_strategy):
        from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator

        tokenizer = mock_processing_strategy.processor.tokenizer
        return MultiModalChatDataCollator(
            tokenizer=tokenizer,
            processing_strategy=mock_processing_strategy,
        )

    def _make_examples(self):
        return [
            {
                "messages": [
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there"},
                ]
            }
        ]

    def test_process_rows_returns_dict(self, mock_processing_strategy):
        """batch must be a dict, not a raw tensor."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert isinstance(batch, dict), (
            "process_rows must return a dict (BatchFeature), not a raw tensor. "
            "If it returns a tensor, apply_chat_template was called without "
            "return_dict=True at the top level."
        )

    def test_process_rows_input_ids_shape(self, mock_processing_strategy):
        """batch['input_ids'] must be a 2-D tensor (batch, seq_len)."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert "input_ids" in batch
        assert isinstance(batch["input_ids"], torch.Tensor)
        assert batch["input_ids"].ndim == 2, (
            f"input_ids must be 2-D (batch, seq_len), got shape {batch['input_ids'].shape}"
        )

    def test_process_rows_labels_shape(self, mock_processing_strategy):
        """batch['labels'] must be a 2-D tensor matching input_ids shape."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert "labels" in batch
        assert isinstance(batch["labels"], torch.Tensor)
        assert batch["labels"].ndim == 2
        assert batch["labels"].shape == batch["input_ids"].shape

    def test_apply_chat_template_called_with_return_dict_true(
        self, mock_processing_strategy
    ):
        """apply_chat_template must be called with return_dict=True as a keyword arg."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            collator.process_rows(examples)

        call_kwargs = (
            mock_processing_strategy.processor.apply_chat_template.call_args.kwargs
        )
        assert call_kwargs.get("return_dict") is True, (
            "apply_chat_template must be called with return_dict=True as a top-level "
            "keyword argument (not inside processor_kwargs). In transformers 5.x, "
            "apply_chat_template has its own return_dict param (default False) that "
            "controls whether it returns the full BatchFeature or just input_ids."
        )
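
The mocking pattern used by the last test can be tried in isolation with the standard library. This is a minimal sketch in which lists replace tensors; `_fake_template` is a hypothetical helper that mirrors the fixture's side effect, and nothing here touches axolotl or transformers.

```python
from unittest.mock import MagicMock


def _fake_template(*args, **kwargs):
    # Return the full mapping only when return_dict=True is passed at the
    # top level; otherwise return just the input_ids value.
    full = {"input_ids": [[1, 2, 3]]}
    return full if kwargs.get("return_dict", False) else full["input_ids"]


processor = MagicMock()
processor.apply_chat_template = MagicMock(side_effect=_fake_template)

result = processor.apply_chat_template([{"role": "user"}], return_dict=True)

# call_args records the most recent call; .kwargs holds its keyword arguments.
recorded = processor.apply_chat_template.call_args.kwargs
```

Inspecting `call_args.kwargs` checks how the function was called, independently of what it returned, which is why the test suite asserts on both.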