diff --git a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
index 4a171db8f..7cb5c6beb 100644
--- a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
+++ b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
@@ -30,6 +30,7 @@ def _make_fused_forward(original_forward):
         hidden_states: torch.Tensor,
         position_embeddings: torch.Tensor,
         attention_mask: torch.Tensor | None,
+        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
         past_key_values=None,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -65,14 +66,8 @@ def _make_fused_forward(original_forward):
         query_states = query_states.transpose(1, 2)
 
         # ---- K/V path ----
-        # Current transformers stores shared kv on `past_key_values.shared_layers`
-        # (the legacy `shared_kv_states` decoder kwarg was removed). We mirror
-        # the stock attention forward exactly so the dispatch is identical
-        # regardless of whether the model was patched.
-        if self.is_kv_shared_layer and past_key_values is not None:
-            key_states, value_states = past_key_values.shared_layers[
-                self.kv_shared_layer_index
-            ]
+        if self.is_kv_shared_layer:
+            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
             key_states = key_states.to(query_states.device)
             value_states = value_states.to(query_states.device)
         else:
@@ -106,18 +101,12 @@ def _make_fused_forward(original_forward):
             value_states = fused_rms_norm_noscale(value_states, eps=eps)
             value_states = value_states.transpose(1, 2)
 
-        if past_key_values is not None:
-            if not self.is_kv_shared_layer:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx
-                )
-            if self.store_full_length_kv:
-                if not hasattr(past_key_values, "shared_layers"):
-                    past_key_values.shared_layers = {}
-                past_key_values.shared_layers[self.layer_idx] = (
-                    key_states,
-                    value_states,
-                )
+        if past_key_values is not None and not self.is_kv_shared_layer:
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx
+            )
+        if self.store_full_length_kv:
+            shared_kv_states[self.layer_idx] = key_states, value_states
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/tests/monkeypatch/test_gemma4_fused_attn.py b/tests/monkeypatch/test_gemma4_fused_attn.py
index ce8431477..0530d0ee8 100644
--- a/tests/monkeypatch/test_gemma4_fused_attn.py
+++ b/tests/monkeypatch/test_gemma4_fused_attn.py
@@ -3,16 +3,16 @@
 These tests exercise the patched ``Gemma4TextAttention.forward`` against
 the stock implementation it replaces. The hybrid Gemma 4 model intentionally
 mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
-layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that:
+layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
+partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
+exercised end-to-end (this is the bug originally documented in
+``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
 
- 1. The partial-rotary RMSNorm+RoPE path through the fused Triton kernel
-    gets exercised end-to-end (this is the bug originally documented in
-    ``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
- 2. The fused forward must match the current transformers attention API,
-    where the decoder layer no longer passes a ``shared_kv_states`` kwarg
-    and shared kv lives on ``past_key_values.shared_layers``. An older
-    fused_forward signature would raise ``TypeError: ... missing 1
-    required positional argument: 'shared_kv_states'`` here.
+The full-model forward also pins that the fused forward keeps accepting
+whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
+installed transformers version — so any future signature drift on
+upstream's side trips a clear failure here instead of a confusing
+TypeError deep in a training run.
 """
 
 import pytest
@@ -86,15 +86,13 @@ def _build_model(seed=0):
 
 class TestFusedAttnSignature:
     """The fused forward must accept the same call shape as
-    ``Gemma4TextDecoderLayer`` produces under the current transformers API
-    (no ``shared_kv_states`` kwarg)."""
+    ``Gemma4TextDecoderLayer`` produces in the installed transformers
+    version. Any signature drift surfaces here as a TypeError."""
 
     def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
-        """Regression for the API drift: decoder layer calls
-        ``self.self_attn(hidden_states=..., position_embeddings=...,
-        attention_mask=..., position_ids=..., past_key_values=...)`` and
-        nothing else. A signature with a positional ``shared_kv_states``
-        used to raise ``TypeError`` here before reaching the kernel."""
+        """Run a model forward that exercises the real
+        ``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
+        the fused patch installed."""
         from axolotl.monkeypatch.models.gemma4.fused_attn import (
             patch_gemma4_fused_attn,
         )
@@ -126,6 +124,7 @@ class TestFusedAttnPerLayerCorrectness:
             hidden_states=hidden_states,
             position_embeddings=(cos, sin),
             attention_mask=None,
+            shared_kv_states={},
        )
         return out
 
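For reviewers unfamiliar with the shared-KV plumbing this patch relies on, the sketch below illustrates the threading pattern in isolation: the caller builds one `shared_kv_states` dict per forward pass, layers that keep full-length K/V write into it under their `layer_idx`, and KV-sharing layers read from it via `kv_shared_layer_index`. This is a minimal standalone illustration, not code from axolotl or transformers; `ToyAttention` and its stand-in "projections" are hypothetical, and only the attribute names mirror the diff.

# Minimal sketch (assumed, illustrative only) of the per-forward shared-KV
# dict threading used by the patched attention forward above.
import torch


class ToyAttention(torch.nn.Module):
    def __init__(self, layer_idx, kv_shared_layer_index=None, store_full_length_kv=False):
        super().__init__()
        self.layer_idx = layer_idx
        self.kv_shared_layer_index = kv_shared_layer_index
        self.is_kv_shared_layer = kv_shared_layer_index is not None
        self.store_full_length_kv = store_full_length_kv

    def forward(self, hidden_states, shared_kv_states):
        if self.is_kv_shared_layer:
            # Reuse K/V computed earlier in this same forward pass.
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
        else:
            # Stand-in for the real K/V projections, norms, and RoPE.
            key_states, value_states = hidden_states, hidden_states
            if self.store_full_length_kv:
                # Publish full-length K/V for later sharing layers.
                shared_kv_states[self.layer_idx] = (key_states, value_states)
        return hidden_states + key_states + value_states


layers = [
    ToyAttention(layer_idx=0, store_full_length_kv=True),
    ToyAttention(layer_idx=1, kv_shared_layer_index=0),
]
x = torch.randn(1, 4, 8)
shared_kv_states = {}  # a fresh dict per forward pass
for layer in layers:
    x = layer(x, shared_kv_states=shared_kv_states)

Because the dict is created anew on every forward pass, no K/V state leaks across batches; the per-layer test above passes `shared_kv_states={}` for the same reason.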