"""Tests for the Gemma 4 hybrid-attention mask fix.
|
|
|
|
These tests pin the single critical behavior: after installing the patch,
|
|
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
|
|
the underlying mask builder regardless of what the caller's config says.
|
|
This is what keeps full-attention (head_dim=512) global layers from
|
|
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
|
|
the 2D FA2 mask that would be built from the model-level config.
|
|
|
|
The tests use a mocked ``create_causal_mask`` so they don't have to load
|
|
a real 26B Gemma 4 model or even have access to its weights. What matters
|
|
for the bug fix is which config is handed to the mask factory, not the
|
|
factory's actual output.
|
|
"""
|
|
|
|
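# For orientation only: judging from the assertions below, the wrapper that
# ``patch_gemma4_hybrid_mask`` installs is expected to behave roughly like
# this sketch. The names ``_wrapper`` and ``_original`` are placeholders for
# illustration, not taken from the patch module:
#
#     def _wrapper(config, *args, **kwargs):
#         sdpa_config = copy.copy(config)            # shallow copy; never mutate the caller's config
#         sdpa_config._attn_implementation = "sdpa"  # force a 4D SDPA-format mask
#         return _original(sdpa_config, *args, **kwargs)
#
#     _wrapper._axolotl_original = _original         # exposed so unpatch can restore it
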
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

pytest.importorskip(
    "transformers.models.gemma4",
    reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
)


@pytest.fixture
def restore_gemma4_module():
    """Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
    each test so patch state doesn't leak across the suite."""
    from transformers.models.gemma4 import modeling_gemma4

    saved = modeling_gemma4.create_causal_mask
    yield modeling_gemma4
    modeling_gemma4.create_causal_mask = saved
    # Reset the module-level flag so the next test can re-install cleanly.
    from axolotl.monkeypatch import gemma4_hybrid_mask

    gemma4_hybrid_mask._PATCH_APPLIED = False


def test_patch_replaces_create_causal_mask(restore_gemma4_module):
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    original = modeling_gemma4.create_causal_mask
    assert patch_gemma4_hybrid_mask() is True

    assert modeling_gemma4.create_causal_mask is not original
    assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
        "patched wrapper must expose the original reference for teardown"
    )


def test_patch_is_idempotent(restore_gemma4_module):
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    patch_gemma4_hybrid_mask()
    wrapper_first = modeling_gemma4.create_causal_mask

    # Second call must not re-wrap the already-wrapped function (which
    # would leak the original reference through a chain of wrappers).
    patch_gemma4_hybrid_mask()
    wrapper_second = modeling_gemma4.create_causal_mask

    assert wrapper_first is wrapper_second


def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
    """Core invariant: when the patched wrapper is called with a config
    that says ``flash_attention_2``, the underlying mask factory receives
    a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.

    Without this, the full-attention global layers get a 2D FA2 mask and
    crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
    """
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    # Swap in a mock BEFORE installing the patch so the wrapper captures
    # it as the "original". The mock records every call so we can inspect
    # what config got passed through.
    mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
    modeling_gemma4.create_causal_mask = mock_factory
    patch_gemma4_hybrid_mask()

    # Caller-supplied config says FA2 (that's the model-level setting).
    caller_config = SimpleNamespace(
        _attn_implementation="flash_attention_2",
        head_dim=512,
        some_other_attr="preserved",
    )
    result = modeling_gemma4.create_causal_mask(
        caller_config,
        inputs_embeds=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
    )

    # Wrapper returned whatever the mock returned — no transformation of
    # the result itself.
    assert result == "mask_4d"

    # The mock was called exactly once with a config whose
    # ``_attn_implementation`` is sdpa, NOT the caller's fa2.
    assert mock_factory.call_count == 1
    (passed_config, *_), passed_kwargs = mock_factory.call_args
    assert passed_config._attn_implementation == "sdpa"

    # The wrapper must NOT mutate the caller's config in place — other
    # mask builders (e.g. create_sliding_window_causal_mask) read from
    # the same config and must still see fa2.
    assert caller_config._attn_implementation == "flash_attention_2"

    # Other attributes on the config must be preserved so the underlying
    # factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
    assert passed_config.head_dim == 512
    assert passed_config.some_other_attr == "preserved"


def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
    """The wrapper must forward positional + keyword args to the original
    unchanged, so transformers' own call-site in Gemma4TextModel.forward
    keeps working across minor transformers-version signature drift."""
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    mock_factory = MagicMock(return_value="mask")
    modeling_gemma4.create_causal_mask = mock_factory
    patch_gemma4_hybrid_mask()

    caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
    modeling_gemma4.create_causal_mask(
        caller_config,
        "positional_arg",
        inputs_embeds="embeds",
        attention_mask="mask_2d",
        past_key_values="cache",
        position_ids="positions",
        or_mask_function="or_fn",
    )

    args, kwargs = mock_factory.call_args
    # First positional (after config override) is preserved.
    assert args[1] == "positional_arg"
    # All kwargs are forwarded untouched.
    assert kwargs["inputs_embeds"] == "embeds"
    assert kwargs["attention_mask"] == "mask_2d"
    assert kwargs["past_key_values"] == "cache"
    assert kwargs["position_ids"] == "positions"
    assert kwargs["or_mask_function"] == "or_fn"


def test_unpatch_restores_original(restore_gemma4_module):
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import (
        patch_gemma4_hybrid_mask,
        unpatch_gemma4_hybrid_mask,
    )

    sentinel = MagicMock(name="original")
    modeling_gemma4.create_causal_mask = sentinel
    patch_gemma4_hybrid_mask()
    assert modeling_gemma4.create_causal_mask is not sentinel

    unpatch_gemma4_hybrid_mask()
    assert modeling_gemma4.create_causal_mask is sentinel


def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
    from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask

    # Should be a no-op, no exception.
    unpatch_gemma4_hybrid_mask()


def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
    """Only ``create_causal_mask`` is overridden — the sliding-window
    factory must remain bound to its original to preserve FA2 masks for
    the sliding-attention layers. If we accidentally patch both, the
    sliding layers get SDPA format and lose the FA2 speedup."""
    modeling_gemma4 = restore_gemma4_module
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
        pytest.skip("transformers version has no create_sliding_window_causal_mask")

    sliding_before = modeling_gemma4.create_sliding_window_causal_mask
    patch_gemma4_hybrid_mask()
    sliding_after = modeling_gemma4.create_sliding_window_causal_mask
    assert sliding_after is sliding_before


# ---------------------------------------------------------------------------
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
#
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
# model with 2 layers (one sliding, one full_attention), apply the hybrid
# attention path end-to-end, and run a forward pass with a padded
# attention_mask at a long-ish seq len. The invariant we're pinning is that
# the full_attention layer does not crash with the
#     "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
# tokens in the FSDP2 training run.
# ---------------------------------------------------------------------------


def _build_tiny_gemma4_text_model():
    """Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
    import torch
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel

    cfg = Gemma4TextConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        head_dim=32,
        layer_types=["sliding_attention", "full_attention"],
        sliding_window=64,
        max_position_embeddings=2048,
        hidden_size_per_layer_input=16,
        vocab_size_per_layer_input=128,
    )
    # Keep the model-level attn implementation on SDPA so this tiny model can
    # run a CPU forward pass without flash-attn. The pilot's fa2-at-model-level
    # setting is exercised against the real mask factory in
    # test_patched_create_causal_mask_returns_4d_for_real_config below.
    cfg._attn_implementation = "sdpa"
    torch.manual_seed(42)
    model = Gemma4TextModel(cfg).eval()
    return model, cfg


def _apply_hybrid_attn_inline(model, cfg):
    """Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
    to a model, without needing a full PatchManager / pydantic cfg."""
    import copy

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    for layer_idx, layer in enumerate(model.layers):
        if cfg.layer_types[layer_idx] != "sliding_attention":
            attn = getattr(layer, "self_attn", None)
            if attn is not None and hasattr(attn, "config"):
                sdpa_cfg = copy.copy(attn.config)
                sdpa_cfg._attn_implementation = "sdpa"
                attn.config = sdpa_cfg
    patch_gemma4_hybrid_mask()


def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
    """End-to-end invariant: with the hybrid attn patch applied, a tiny
    Gemma4TextModel runs a forward at long context (1024 tokens) with
    real padding in the attention mask, producing the expected output
    shape. This exercises the actual code path that crashed the pilot
    without needing a real 26B checkpoint or CUDA."""
    import torch

    model, cfg = _build_tiny_gemma4_text_model()
    _apply_hybrid_attn_inline(model, cfg)

    B, S = 2, 1024
    input_ids = torch.randint(0, cfg.vocab_size, (B, S))
    attn_mask = torch.ones(B, S, dtype=torch.long)
    # Pad positions in the second row. Without padding, SDPA falls back to
    # ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
    # mask to exercise the actual bug site.
    attn_mask[1, S // 2 :] = 0

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn_mask)

    assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
    assert torch.isfinite(out.last_hidden_state).all()


def test_patched_create_causal_mask_returns_4d_for_real_config(
    restore_gemma4_module,
):
    """Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
    and verify the returned mask is a 4D tensor — which is the shape the
    SDPA-patched global layers need. Without the patch and with a
    caller-supplied FA2 config this would return a 2D mask and the layer
    would crash at long context."""
    import torch
    from transformers.cache_utils import DynamicCache
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    patch_gemma4_hybrid_mask()
    modeling_gemma4 = restore_gemma4_module

    cfg = Gemma4TextConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        head_dim=32,
        layer_types=["sliding_attention", "full_attention"],
        sliding_window=64,
        max_position_embeddings=2048,
        hidden_size_per_layer_input=16,
        vocab_size_per_layer_input=128,
    )
    # Simulate the pilot: caller says flash_attention_2, but global layers
    # were switched to SDPA per-layer. Without the patch, create_causal_mask
    # would return an FA2 2D mask here and the SDPA layer would crash.
    cfg._attn_implementation = "flash_attention_2"

    B, S = 2, 1024
    inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
    attention_mask = torch.ones((B, S), dtype=torch.long)
    attention_mask[1, S // 2 :] = 0  # force the 4D materialized path
    position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
    past_key_values = DynamicCache(config=cfg)

    mask = modeling_gemma4.create_causal_mask(
        config=cfg,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    assert mask is not None
    assert isinstance(mask, torch.Tensor)
    assert mask.dim() == 4, (
        f"expected a 4D SDPA-format mask, got {mask.dim()}D "
        f"shape={tuple(mask.shape)}. The full_attention global layers need "
        "this shape or they crash at long context."
    )
    assert mask.shape[0] == B
    assert mask.shape[-1] == S
    assert mask.shape[-2] == S

    # Caller's config must be untouched — other code paths still read it.
    assert cfg._attn_implementation == "flash_attention_2"