feat: migrate to transformers 4.48 attention sig
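This tracks the attention refactor shipped in transformers 4.48: attention layers read their sizes from `config` rather than mirrored attributes, receive the rotary tables as a precomputed `position_embeddings=(cos, sin)` pair instead of `position_ids`, and return `(attn_output, attn_weights)` while the KV cache is updated in place. A rough sketch of the call pattern being matched (illustrative only; the exact keyword set varies slightly across 4.4x releases):

    attn_out, attn_weights = attn_module(
        hidden_states,
        attention_mask=attention_mask,
        position_embeddings=(cos, sin),  # rotary tables computed once at the model level
        past_key_value=cache,            # a transformers Cache, mutated in place
    )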
@@ -192,15 +192,6 @@ class LolcatsLinearAttention(nn.Module):
         self.remove_base_attn = remove_base_attn

-        # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
-        self.rotary_config = rotary_config
-        # if isinstance(self.rotary_config, DictDefault):
-        #     self.rotary_config = OmegaConf.to_container(self.rotary_config)
-
-        self.rotary_emb = None
-        if self.base_config is not None and self.rotary_config is None:
-            self.rotary_emb = base_attn.rotary_emb
-
         self.init_weights_(base_attn, remove_base_attn)
         self.init_feature_map_(
             feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
         )
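The per-layer rotary plumbing removed above is no longer the layer's responsibility: in the 4.48-era stack the model owns a single rotary module and broadcasts its output to every layer. A rough sketch of that flow, following the stock `LlamaModel` layout (names are from upstream transformers, not this diff):

    # inside the model's forward, once per step
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    cos, sin = self.rotary_emb(hidden_states, position_ids)  # hidden_states sets dtype/device
    position_embeddings = (cos, sin)
    # ...then passed unchanged into every decoder layer / attention module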
@@ -244,33 +235,15 @@ class LolcatsLinearAttention(nn.Module):
         """
         # Make other attributes accessible
         self.attention_dropout = 0 # We don't use dropout
-        self.hidden_size = base_attn.hidden_size
-        self.num_heads = base_attn.num_heads
+        self.hidden_size = base_attn.config.hidden_size
+        self.num_heads = base_attn.config.num_attention_heads
         self.head_dim = base_attn.head_dim
-        self.num_key_value_heads = base_attn.num_key_value_heads
+        self.num_key_value_heads = base_attn.config.num_key_value_heads
         self.num_key_value_groups = base_attn.num_key_value_groups

         self.q_shape = [self.num_heads, self.head_dim]
         self.k_shape = [self.num_key_value_heads, self.head_dim]
         self.v_shape = [self.num_key_value_heads, self.head_dim]
         device = base_attn.q_proj.weight.device
-        # Rotary embeddings
-        if self.rotary_emb is None:
-            self.max_position_embeddings = base_attn.max_position_embeddings
-            scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0)
-            if self.rotary_config is None:
-                self.rotary_emb = get_rotary_embeddings(
-                    rope_scaling_type=None,
-                    head_dim=self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,  # base_attn.rotary_emb.max_position_embeddings,
-                    rope_theta=base_attn.rotary_emb.base,
-                    rope_scaling_factor=scaling_factor,  # base_attn.rotary_emb.scaling_factor,
-                    device=device,
-                )
-            else:
-                if "device" not in self.rotary_config:
-                    self.rotary_config["device"] = device
-                self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
-
         # Copy original model projection layers
         self.q_proj = base_attn.q_proj
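If a standalone rotary module is still needed outside a full model forward (for example when driving `process_qkv` directly in a distillation script), it can be rebuilt from the base config instead of the deleted `get_rotary_embeddings(...)` block. A minimal sketch, assuming a Llama-family `base_attn.config` and the 4.48-era constructor:

    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    # rope_theta, scaling type/factor, and max positions are all read from the config,
    # so nothing needs to be copied off the attention layer anymore
    rotary_emb = LlamaRotaryEmbedding(config=base_attn.config)
    cos, sin = rotary_emb(hidden_states, position_ids)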
@@ -293,9 +266,9 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None,
-    ): # "legacy" cache approach
+    ):
         """
         Compute queries, keys, and values
         """
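The new `position_embeddings` argument is the `(cos, sin)` pair produced by the model-level rotary embedding; for the stock Llama rotary each tensor has shape `(batch, seq_len, head_dim)`. A quick illustrative check:

    cos, sin = position_embeddings
    assert cos.shape == sin.shape
    assert cos.shape[-1] == self.head_dim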
@@ -325,18 +298,9 @@ class LolcatsLinearAttention(nn.Module):
         else:
             kv_seq_len += past_key_value[0].shape[-2]

-        # Apply rotary embeddings and repeat for GQA
-        if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
-            kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
-
-        if self.rotary_emb is None:
-            raise ValueError("Rotary embeddings not initialized")
-
-        try: # As in Transformers v4.36
-            cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-        except TypeError: # As in Transformers v4.39+
-            cos, sin = self.rotary_emb(v, position_ids)
+        # Apply rotary embeddings
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
             q, k = apply_rotary_pos_emb(q, k, cos, sin)

         k = repeat_kv(k, self.num_key_value_groups)
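Note the behavioural shift in this hunk: the old path probed two rotary call signatures and raised a `ValueError` when no rotary module was set, while the new path simply skips rotary application when `position_embeddings` is `None`. A small self-contained sketch of the new application path, using the upstream Llama helpers (toy shapes; real values come from the wrapped `base_attn.config`, and the `LlamaRotaryEmbedding(config=...)` constructor assumes the 4.48-era API):

    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import (
        LlamaRotaryEmbedding,
        apply_rotary_pos_emb,
        repeat_kv,
    )

    cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
    b, l = 2, 8
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    q = torch.randn(b, cfg.num_attention_heads, l, head_dim)
    k = torch.randn(b, cfg.num_key_value_heads, l, head_dim)

    rope = LlamaRotaryEmbedding(config=cfg)
    position_ids = torch.arange(l).unsqueeze(0).expand(b, -1)
    cos, sin = rope(q, position_ids)             # this pair is `position_embeddings`
    q, k = apply_rotary_pos_emb(q, k, cos, sin)  # no position_ids argument needed
    k = repeat_kv(k, cfg.num_attention_heads // cfg.num_key_value_heads)  # GQA repeat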
@@ -347,7 +311,7 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None, # "legacy" cache approach
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -359,7 +323,7 @@ class LolcatsLinearAttention(nn.Module):
         """
         b, l, _ = hidden_states.size()
         q, k, v, kv_seq_len = self.process_qkv(
-            hidden_states, attention_mask, position_ids, past_key_value
+            hidden_states, attention_mask, position_embeddings, past_key_value
         )

         if self.base_inference:
@@ -456,7 +420,7 @@ class LolcatsLinearAttention(nn.Module):
         y_true = self.o_proj(y_true)
         attn_weights = None

-        return y_true, attn_weights, past_key_value
+        return y_true, attn_weights


 class LinearAttentionState(Cache):
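With this change the module follows the two-value return convention: `(attn_output, attn_weights)`, with the KV cache (here the `LinearAttentionState` `Cache` subclass) mutated in place through `past_key_value` rather than threaded back through the return. Roughly, at the call site:

    # pre-4.48 style
    y, attn_weights, past_key_value = attn(hidden_states, ...)
    # 4.48 style: the Cache passed via past_key_value is updated in place
    y, attn_weights = attn(hidden_states, past_key_value=cache, ...)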