revert shared_kv_states workaround with transformers 5.5.4
This commit is contained in:
@@ -3,16 +3,16 @@
|
||||
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
||||
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
||||
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that:
|
||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
||||
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
||||
exercised end-to-end (this is the bug originally documented in
|
||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||
|
||||
1. The partial-rotary RMSNorm+RoPE path through the fused Triton kernel
|
||||
gets exercised end-to-end (this is the bug originally documented in
|
||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||
2. The fused forward must match the current transformers attention API,
|
||||
where the decoder layer no longer passes a ``shared_kv_states`` kwarg
|
||||
and shared kv lives on ``past_key_values.shared_layers``. An older
|
||||
fused_forward signature would raise ``TypeError: ... missing 1
|
||||
required positional argument: 'shared_kv_states'`` here.
|
||||
The full-model forward also pins that the fused forward keeps accepting
|
||||
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
||||
installed transformers version — so any future signature drift on
|
||||
upstream's side trips a clear failure here instead of a confusing
|
||||
TypeError deep in a training run.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -86,15 +86,13 @@ def _build_model(seed=0):
|
||||
|
||||
class TestFusedAttnSignature:
|
||||
"""The fused forward must accept the same call shape as
|
||||
``Gemma4TextDecoderLayer`` produces under the current transformers API
|
||||
(no ``shared_kv_states`` kwarg)."""
|
||||
``Gemma4TextDecoderLayer`` produces in the installed transformers
|
||||
version. Any signature drift surfaces here as a TypeError."""
|
||||
|
||||
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
||||
"""Regression for the API drift: decoder layer calls
|
||||
``self.self_attn(hidden_states=..., position_embeddings=...,
|
||||
attention_mask=..., position_ids=..., past_key_values=...)`` and
|
||||
nothing else. A signature with a positional ``shared_kv_states``
|
||||
used to raise ``TypeError`` here before reaching the kernel."""
|
||||
"""Run a model forward that exercises the real
|
||||
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
|
||||
the fused patch installed."""
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
@@ -126,6 +124,7 @@ class TestFusedAttnPerLayerCorrectness:
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=(cos, sin),
|
||||
attention_mask=None,
|
||||
shared_kv_states={},
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user