diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py b/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py
index abce4dd93..352459641 100644
--- a/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py
+++ b/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py
@@ -192,15 +192,6 @@ class LolcatsLinearAttention(nn.Module):
 
         self.remove_base_attn = remove_base_attn
 
-        # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
-        self.rotary_config = rotary_config
-        # if isinstance(self.rotary_config, DictDefault):
-        #     self.rotary_config = OmegaConf.to_container(self.rotary_config)
-
-        self.rotary_emb = None
-        if self.base_config is not None and self.rotary_config is None:
-            self.rotary_emb = base_attn.rotary_emb
-
         self.init_weights_(base_attn, remove_base_attn)
         self.init_feature_map_(
             feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
@@ -244,33 +235,15 @@ class LolcatsLinearAttention(nn.Module):
         """
         # Make other attributes accessible
         self.attention_dropout = 0  # We don't use dropout
-        self.hidden_size = base_attn.hidden_size
-        self.num_heads = base_attn.num_heads
+        self.hidden_size = base_attn.config.hidden_size
+        self.num_heads = base_attn.config.num_attention_heads
         self.head_dim = base_attn.head_dim
-        self.num_key_value_heads = base_attn.num_key_value_heads
+        self.num_key_value_heads = base_attn.config.num_key_value_heads
         self.num_key_value_groups = base_attn.num_key_value_groups
 
         self.q_shape = [self.num_heads, self.head_dim]
         self.k_shape = [self.num_key_value_heads, self.head_dim]
         self.v_shape = [self.num_key_value_heads, self.head_dim]
-        device = base_attn.q_proj.weight.device
-        # Rotary embeddings
-        if self.rotary_emb is None:
-            self.max_position_embeddings = base_attn.max_position_embeddings
-            scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0)
-            if self.rotary_config is None:
-                self.rotary_emb = get_rotary_embeddings(
-                    rope_scaling_type=None,
-                    head_dim=self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,  # base_attn.rotary_emb.max_position_embeddings,
-                    rope_theta=base_attn.rotary_emb.base,
-                    rope_scaling_factor=scaling_factor,  # base_attn.rotary_emb.scaling_factor,
-                    device=device,
-                )
-            else:
-                if "device" not in self.rotary_config:
-                    self.rotary_config["device"] = device
-                self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
 
         # Copy original model projection layers
         self.q_proj = base_attn.q_proj
@@ -293,9 +266,9 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None,
-    ):  # "legacy" cache approach
+    ):
         """
         Compute queries, keys, and values
         """
@@ -325,18 +298,9 @@ class LolcatsLinearAttention(nn.Module):
         else:
             kv_seq_len += past_key_value[0].shape[-2]
 
-        # Apply rotary embeddings and repeat for GQA
-        if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
-            kv_seq_len = position_ids[0, -1] + 1  # hack for adjusting position ids
-
-        if self.rotary_emb is None:
-            raise ValueError("Rotary embeddings not initialized")
-
-        try:  # As in Transformers v4.36
-            cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-        except TypeError:  # As in Transformers v4.39+
-            cos, sin = self.rotary_emb(v, position_ids)
+        # Apply rotary embeddings
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
             q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
         k = repeat_kv(k, self.num_key_value_groups)
@@ -347,7 +311,7 @@ class LolcatsLinearAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         past_key_value: Optional[Any] = None,  # "legacy" cache approach
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -359,7 +323,7 @@ class LolcatsLinearAttention(nn.Module):
         """
         b, l, _ = hidden_states.size()
         q, k, v, kv_seq_len = self.process_qkv(
-            hidden_states, attention_mask, position_ids, past_key_value
+            hidden_states, attention_mask, position_embeddings, past_key_value
        )
 
         if self.base_inference:
@@ -456,7 +420,7 @@ class LolcatsLinearAttention(nn.Module):
             y_true = self.o_proj(y_true)
             attn_weights = None
 
-        return y_true, attn_weights, past_key_value
+        return y_true, attn_weights
 
 
 class LinearAttentionState(Cache):
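
Usage note (not part of the diff): the attention layer no longer builds its own rotary
embeddings; callers are expected to pass the precomputed (cos, sin) pair, matching recent
transformers versions where a single model-level rotary module produces position_embeddings
once per forward pass and each attention layer receives the tuple. A minimal sketch, assuming
a loaded Llama-style model and a patched layer bound to `lolcats_attn`; the names
`lolcats_attn`, `model`, and the `model.model.rotary_emb` access path are illustrative:

    import torch

    hidden_states = torch.randn(1, 16, model.config.hidden_size)
    position_ids = torch.arange(16).unsqueeze(0)
    # Model-level RoPE module returns the (cos, sin) tuple expected by position_embeddings
    cos, sin = model.model.rotary_emb(hidden_states, position_ids)
    y, attn_weights = lolcats_attn(
        hidden_states,
        attention_mask=None,
        position_embeddings=(cos, sin),  # replaces the old position_ids argument
    )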