revert shared_kv_states workaround with transformers 5.5.4
This commit is contained in:
@@ -30,6 +30,7 @@ def _make_fused_forward(original_forward):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor,
|
position_embeddings: torch.Tensor,
|
||||||
attention_mask: torch.Tensor | None,
|
attention_mask: torch.Tensor | None,
|
||||||
|
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
@@ -65,14 +66,8 @@ def _make_fused_forward(original_forward):
|
|||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
# ---- K/V path ----
|
# ---- K/V path ----
|
||||||
# Current transformers stores shared kv on `past_key_values.shared_layers`
|
if self.is_kv_shared_layer:
|
||||||
# (the legacy `shared_kv_states` decoder kwarg was removed). We mirror
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
# the stock attention forward exactly so the dispatch is identical
|
|
||||||
# regardless of whether the model was patched.
|
|
||||||
if self.is_kv_shared_layer and past_key_values is not None:
|
|
||||||
key_states, value_states = past_key_values.shared_layers[
|
|
||||||
self.kv_shared_layer_index
|
|
||||||
]
|
|
||||||
key_states = key_states.to(query_states.device)
|
key_states = key_states.to(query_states.device)
|
||||||
value_states = value_states.to(query_states.device)
|
value_states = value_states.to(query_states.device)
|
||||||
else:
|
else:
|
||||||
@@ -106,18 +101,12 @@ def _make_fused_forward(original_forward):
|
|||||||
value_states = fused_rms_norm_noscale(value_states, eps=eps)
|
value_states = fused_rms_norm_noscale(value_states, eps=eps)
|
||||||
value_states = value_states.transpose(1, 2)
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
if past_key_values is not None:
|
if past_key_values is not None and not self.is_kv_shared_layer:
|
||||||
if not self.is_kv_shared_layer:
|
key_states, value_states = past_key_values.update(
|
||||||
key_states, value_states = past_key_values.update(
|
key_states, value_states, self.layer_idx
|
||||||
key_states, value_states, self.layer_idx
|
)
|
||||||
)
|
if self.store_full_length_kv:
|
||||||
if self.store_full_length_kv:
|
shared_kv_states[self.layer_idx] = key_states, value_states
|
||||||
if not hasattr(past_key_values, "shared_layers"):
|
|
||||||
past_key_values.shared_layers = {}
|
|
||||||
past_key_values.shared_layers[self.layer_idx] = (
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
||||||
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
||||||
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
||||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that:
|
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
||||||
|
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
||||||
|
exercised end-to-end (this is the bug originally documented in
|
||||||
|
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||||
|
|
||||||
1. The partial-rotary RMSNorm+RoPE path through the fused Triton kernel
|
The full-model forward also pins that the fused forward keeps accepting
|
||||||
gets exercised end-to-end (this is the bug originally documented in
|
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
||||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
installed transformers version — so any future signature drift on
|
||||||
2. The fused forward must match the current transformers attention API,
|
upstream's side trips a clear failure here instead of a confusing
|
||||||
where the decoder layer no longer passes a ``shared_kv_states`` kwarg
|
TypeError deep in a training run.
|
||||||
and shared kv lives on ``past_key_values.shared_layers``. An older
|
|
||||||
fused_forward signature would raise ``TypeError: ... missing 1
|
|
||||||
required positional argument: 'shared_kv_states'`` here.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -86,15 +86,13 @@ def _build_model(seed=0):
|
|||||||
|
|
||||||
class TestFusedAttnSignature:
|
class TestFusedAttnSignature:
|
||||||
"""The fused forward must accept the same call shape as
|
"""The fused forward must accept the same call shape as
|
||||||
``Gemma4TextDecoderLayer`` produces under the current transformers API
|
``Gemma4TextDecoderLayer`` produces in the installed transformers
|
||||||
(no ``shared_kv_states`` kwarg)."""
|
version. Any signature drift surfaces here as a TypeError."""
|
||||||
|
|
||||||
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
||||||
"""Regression for the API drift: decoder layer calls
|
"""Run a model forward that exercises the real
|
||||||
``self.self_attn(hidden_states=..., position_embeddings=...,
|
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
|
||||||
attention_mask=..., position_ids=..., past_key_values=...)`` and
|
the fused patch installed."""
|
||||||
nothing else. A signature with a positional ``shared_kv_states``
|
|
||||||
used to raise ``TypeError`` here before reaching the kernel."""
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
patch_gemma4_fused_attn,
|
patch_gemma4_fused_attn,
|
||||||
)
|
)
|
||||||
@@ -126,6 +124,7 @@ class TestFusedAttnPerLayerCorrectness:
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=(cos, sin),
|
position_embeddings=(cos, sin),
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
shared_kv_states={},
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user