Revert the shared_kv_states workaround with transformers 5.5.4
This commit is contained in:
@@ -30,6 +30,7 @@ def _make_fused_forward(original_forward):
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
@@ -65,14 +66,8 @@ def _make_fused_forward(original_forward):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# ---- K/V path ----
|
||||
# Current transformers stores shared kv on `past_key_values.shared_layers`
|
||||
# (the legacy `shared_kv_states` decoder kwarg was removed). We mirror
|
||||
# the stock attention forward exactly so the dispatch is identical
|
||||
# regardless of whether the model was patched.
|
||||
if self.is_kv_shared_layer and past_key_values is not None:
|
||||
key_states, value_states = past_key_values.shared_layers[
|
||||
self.kv_shared_layer_index
|
||||
]
|
||||
if self.is_kv_shared_layer:
|
||||
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
@@ -106,18 +101,12 @@ def _make_fused_forward(original_forward):
|
||||
value_states = fused_rms_norm_noscale(value_states, eps=eps)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
if not self.is_kv_shared_layer:
|
||||
key_states, value_states = past_key_values.update(
|
||||
key_states, value_states, self.layer_idx
|
||||
)
|
||||
if self.store_full_length_kv:
|
||||
if not hasattr(past_key_values, "shared_layers"):
|
||||
past_key_values.shared_layers = {}
|
||||
past_key_values.shared_layers[self.layer_idx] = (
|
||||
key_states,
|
||||
value_states,
|
||||
)
|
||||
if past_key_values is not None and not self.is_kv_shared_layer:
|
||||
key_states, value_states = past_key_values.update(
|
||||
key_states, value_states, self.layer_idx
|
||||
)
|
||||
if self.store_full_length_kv:
|
||||
shared_kv_states[self.layer_idx] = key_states, value_states
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
|
||||
Reference in New Issue
Block a user