feat: migrate to transformers 4.48 attention sig
This commit is contained in:
@@ -192,15 +192,6 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
|
|
||||||
self.remove_base_attn = remove_base_attn
|
self.remove_base_attn = remove_base_attn
|
||||||
|
|
||||||
# Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
|
|
||||||
self.rotary_config = rotary_config
|
|
||||||
# if isinstance(self.rotary_config, DictDefault):
|
|
||||||
# self.rotary_config = OmegaConf.to_container(self.rotary_config)
|
|
||||||
|
|
||||||
self.rotary_emb = None
|
|
||||||
if self.base_config is not None and self.rotary_config is None:
|
|
||||||
self.rotary_emb = base_attn.rotary_emb
|
|
||||||
|
|
||||||
self.init_weights_(base_attn, remove_base_attn)
|
self.init_weights_(base_attn, remove_base_attn)
|
||||||
self.init_feature_map_(
|
self.init_feature_map_(
|
||||||
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
|
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
|
||||||
@@ -244,33 +235,15 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Make other attributes accessible
|
# Make other attributes accessible
|
||||||
self.attention_dropout = 0 # We don't use dropout
|
self.attention_dropout = 0 # We don't use dropout
|
||||||
self.hidden_size = base_attn.hidden_size
|
self.hidden_size = base_attn.config.hidden_size
|
||||||
self.num_heads = base_attn.num_heads
|
self.num_heads = base_attn.config.num_attention_heads
|
||||||
self.head_dim = base_attn.head_dim
|
self.head_dim = base_attn.head_dim
|
||||||
self.num_key_value_heads = base_attn.num_key_value_heads
|
self.num_key_value_heads = base_attn.config.num_key_value_heads
|
||||||
self.num_key_value_groups = base_attn.num_key_value_groups
|
self.num_key_value_groups = base_attn.num_key_value_groups
|
||||||
|
|
||||||
self.q_shape = [self.num_heads, self.head_dim]
|
self.q_shape = [self.num_heads, self.head_dim]
|
||||||
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
||||||
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
||||||
device = base_attn.q_proj.weight.device
|
|
||||||
# Rotary embeddings
|
|
||||||
if self.rotary_emb is None:
|
|
||||||
self.max_position_embeddings = base_attn.max_position_embeddings
|
|
||||||
scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0)
|
|
||||||
if self.rotary_config is None:
|
|
||||||
self.rotary_emb = get_rotary_embeddings(
|
|
||||||
rope_scaling_type=None,
|
|
||||||
head_dim=self.head_dim,
|
|
||||||
max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
|
|
||||||
rope_theta=base_attn.rotary_emb.base,
|
|
||||||
rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if "device" not in self.rotary_config:
|
|
||||||
self.rotary_config["device"] = device
|
|
||||||
self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
|
|
||||||
|
|
||||||
# Copy original model projection layers
|
# Copy original model projection layers
|
||||||
self.q_proj = base_attn.q_proj
|
self.q_proj = base_attn.q_proj
|
||||||
@@ -293,9 +266,9 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
past_key_value: Optional[Any] = None,
|
past_key_value: Optional[Any] = None,
|
||||||
): # "legacy" cache approach
|
):
|
||||||
"""
|
"""
|
||||||
Compute queries, keys, and values
|
Compute queries, keys, and values
|
||||||
"""
|
"""
|
||||||
@@ -325,18 +298,9 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
# Apply rotary embeddings and repeat for GQA
|
# Apply rotary embeddings
|
||||||
if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
|
if position_embeddings is not None:
|
||||||
kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
if self.rotary_emb is None:
|
|
||||||
raise ValueError("Rotary embeddings not initialized")
|
|
||||||
|
|
||||||
try: # As in Transformers v4.36
|
|
||||||
cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
|
|
||||||
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
|
||||||
except TypeError: # As in Transformers v4.39+
|
|
||||||
cos, sin = self.rotary_emb(v, position_ids)
|
|
||||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||||
|
|
||||||
k = repeat_kv(k, self.num_key_value_groups)
|
k = repeat_kv(k, self.num_key_value_groups)
|
||||||
@@ -347,7 +311,7 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
@@ -359,7 +323,7 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
b, l, _ = hidden_states.size()
|
b, l, _ = hidden_states.size()
|
||||||
q, k, v, kv_seq_len = self.process_qkv(
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
hidden_states, attention_mask, position_ids, past_key_value
|
hidden_states, attention_mask, position_embeddings, past_key_value
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.base_inference:
|
if self.base_inference:
|
||||||
@@ -456,7 +420,7 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
y_true = self.o_proj(y_true)
|
y_true = self.o_proj(y_true)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return y_true, attn_weights, past_key_value
|
return y_true, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class LinearAttentionState(Cache):
|
class LinearAttentionState(Cache):
|
||||||
|
|||||||
Reference in New Issue
Block a user