feat: migrate to transformers 4.48 attention sig

Author: NanoCode012
Date:   2025-02-04 01:52:35 +07:00
parent 81731adc1d
commit 0b7b58c8be

@@ -192,15 +192,6 @@ class LolcatsLinearAttention(nn.Module):
         self.remove_base_attn = remove_base_attn
-        # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
-        self.rotary_config = rotary_config
-        # if isinstance(self.rotary_config, DictDefault):
-        #     self.rotary_config = OmegaConf.to_container(self.rotary_config)
-        self.rotary_emb = None
-        if self.base_config is not None and self.rotary_config is None:
-            self.rotary_emb = base_attn.rotary_emb
         self.init_weights_(base_attn, remove_base_attn)
         self.init_feature_map_(
             feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
@@ -244,33 +235,15 @@ class LolcatsLinearAttention(nn.Module):
         """
         # Make other attributes accessible
         self.attention_dropout = 0 # We don't use dropout
-        self.hidden_size = base_attn.hidden_size
-        self.num_heads = base_attn.num_heads
+        self.hidden_size = base_attn.config.hidden_size
+        self.num_heads = base_attn.config.num_attention_heads
         self.head_dim = base_attn.head_dim
-        self.num_key_value_heads = base_attn.num_key_value_heads
+        self.num_key_value_heads = base_attn.config.num_key_value_heads
         self.num_key_value_groups = base_attn.num_key_value_groups
         self.q_shape = [self.num_heads, self.head_dim]
         self.k_shape = [self.num_key_value_heads, self.head_dim]
         self.v_shape = [self.num_key_value_heads, self.head_dim]
-        device = base_attn.q_proj.weight.device
-        # Rotary embeddings
-        if self.rotary_emb is None:
-            self.max_position_embeddings = base_attn.max_position_embeddings
-            scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0)
-            if self.rotary_config is None:
-                self.rotary_emb = get_rotary_embeddings(
-                    rope_scaling_type=None,
-                    head_dim=self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
-                    rope_theta=base_attn.rotary_emb.base,
-                    rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
-                    device=device,
-                )
-            else:
-                if "device" not in self.rotary_config:
-                    self.rotary_config["device"] = device
-                self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
         # Copy original model projection layers
         self.q_proj = base_attn.q_proj
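
Note on the hunk above: in transformers 4.48 the Llama attention module no longer exposes hidden_size, num_heads, or num_key_value_heads as module attributes, so these shape parameters are read from base_attn.config instead. A minimal sketch of the equivalent config lookups, assuming a Llama-style config; the concrete sizes below are illustrative and not taken from this repo:

# Illustrative sketch only: reading shape info from the config, as the
# migrated init_weights_ code above does (config values are made up).
from transformers import LlamaConfig

cfg = LlamaConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
hidden_size = cfg.hidden_size                  # replaces base_attn.hidden_size
num_heads = cfg.num_attention_heads            # replaces base_attn.num_heads
num_key_value_heads = cfg.num_key_value_heads  # replaces base_attn.num_key_value_heads
head_dim = hidden_size // num_heads            # base_attn.head_dim is still available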
@@ -293,9 +266,9 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None,
-    ): # "legacy" cache approach
+    ):
         """
         Compute queries, keys, and values
         """
@@ -325,18 +298,9 @@ class LolcatsLinearAttention(nn.Module):
             else:
                 kv_seq_len += past_key_value[0].shape[-2]
-        # Apply rotary embeddings and repeat for GQA
-        if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
-            kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
-        if self.rotary_emb is None:
-            raise ValueError("Rotary embeddings not initialized")
-        try: # As in Transformers v4.36
-            cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-        except TypeError: # As in Transformers v4.39+
-            cos, sin = self.rotary_emb(v, position_ids)
+        # Apply rotary embeddings
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
         k = repeat_kv(k, self.num_key_value_groups)
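
Note on the hunk above: with the 4.48-style interface, rotary cos/sin are computed once at the model level and handed to every attention layer as position_embeddings, so the per-layer rotary_emb and the version-dependent try/except can be dropped. A rough sketch of that flow, assuming a Llama-style rotary module; tensor shapes and config values below are illustrative:

# Sketch only: (cos, sin) are produced once per forward pass and applied
# without position_ids, matching the migrated code path above.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
rotary_emb = LlamaRotaryEmbedding(config=cfg)

b, l = 1, 8
head_dim = cfg.hidden_size // cfg.num_attention_heads
q = torch.randn(b, cfg.num_attention_heads, l, head_dim)
k = torch.randn(b, cfg.num_key_value_heads, l, head_dim)
position_ids = torch.arange(l).unsqueeze(0)

cos, sin = rotary_emb(q, position_ids)       # model-level in 4.48, passed down as position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)  # no position_ids argument needed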
@@ -347,7 +311,7 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None, # "legacy" cache approach
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -359,7 +323,7 @@ class LolcatsLinearAttention(nn.Module):
         """
         b, l, _ = hidden_states.size()
         q, k, v, kv_seq_len = self.process_qkv(
-            hidden_states, attention_mask, position_ids, past_key_value
+            hidden_states, attention_mask, position_embeddings, past_key_value
         )
         if self.base_inference:
@@ -456,7 +420,7 @@ class LolcatsLinearAttention(nn.Module):
             y_true = self.o_proj(y_true)
             attn_weights = None
-        return y_true, attn_weights, past_key_value
+        return y_true, attn_weights
 class LinearAttentionState(Cache):
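
Note on the final hunk: under the 4.48 attention interface the layer returns only (attn_output, attn_weights); the KV cache is updated in place on the Cache object that is passed in rather than being returned as a third element. A minimal sketch of a forward matching the migrated signature; the class name and placeholder computation are illustrative, not this repo's code:

# Sketch only: layer signature after the migration (2-tuple return).
from typing import Any, Optional, Tuple
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        y = hidden_states          # stand-in for the real attention computation
        attn_weights = None
        return y, attn_weights     # no trailing past_key_value element

layer = TinyAttention()
out, weights = layer(torch.randn(1, 4, 8))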