From 198f01f90216438277e30c8ee2d4013669b6b659 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 26 Dec 2024 20:27:18 -0500 Subject: [PATCH] remove bias term in phi and add custom modeling code --- src/axolotl/cli/integrations/convert_rala.py | 4 +- src/axolotl/integrations/rala/__init__.py | 2 +- .../integrations/rala/auto/__init__.py | 0 .../integrations/rala/auto/llama/__init__.py | 0 .../integrations/rala/auto/llama/attention.py | 280 +++++++++ .../rala/auto/llama/configuration_rala.py | 5 + .../rala/auto/llama/modeling_rala.py | 544 ++++++++++++++++++ src/axolotl/integrations/rala/convert.py | 10 +- src/axolotl/integrations/rala/rala_attn.py | 4 +- 9 files changed, 843 insertions(+), 6 deletions(-) create mode 100644 src/axolotl/integrations/rala/auto/__init__.py create mode 100644 src/axolotl/integrations/rala/auto/llama/__init__.py create mode 100644 src/axolotl/integrations/rala/auto/llama/attention.py create mode 100644 src/axolotl/integrations/rala/auto/llama/configuration_rala.py create mode 100644 src/axolotl/integrations/rala/auto/llama/modeling_rala.py diff --git a/src/axolotl/cli/integrations/convert_rala.py b/src/axolotl/cli/integrations/convert_rala.py index 149cf70cb..d33348b59 100644 --- a/src/axolotl/cli/integrations/convert_rala.py +++ b/src/axolotl/cli/integrations/convert_rala.py @@ -109,9 +109,7 @@ def convert_rala(cfg, cli_args, config_path): modified_cfg["base_model"] = cfg.output_dir modified_cfg["rala_attention"] = True - plugin_class = ( - "axolotl.integrations.rala.RalaPlugin" - ) + plugin_class = "axolotl.integrations.rala.RalaPlugin" if "plugins" in modified_cfg: modified_cfg["plugins"].append(plugin_class) else: diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py index 527f9d910..f9ec54cdb 100644 --- a/src/axolotl/integrations/rala/__init__.py +++ b/src/axolotl/integrations/rala/__init__.py @@ -5,7 +5,7 @@ import logging from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES from axolotl.integrations.base import BasePlugin -from axolotl.integrations.rala.rala_attn import LlamaRALAAttention +from axolotl.integrations.rala.auto.llama.attention import LlamaRALAAttention LOG = logging.getLogger(__name__) diff --git a/src/axolotl/integrations/rala/auto/__init__.py b/src/axolotl/integrations/rala/auto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rala/auto/llama/__init__.py b/src/axolotl/integrations/rala/auto/llama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rala/auto/llama/attention.py b/src/axolotl/integrations/rala/auto/llama/attention.py new file mode 100644 index 000000000..f62289216 --- /dev/null +++ b/src/axolotl/integrations/rala/auto/llama/attention.py @@ -0,0 +1,280 @@ +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache +from transformers.models.llama.modeling_llama import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, +) + + +def kappa(x: torch.Tensor) -> torch.Tensor: + """ + The paper uses κ(x) = ELU(x) + 1. + x is assumed to be [batch, n_heads, seq_len, head_dim]. + """ + return F.elu(x) + 1 + + +class LlamaRALAAttention(nn.Module): + """ + LlamaAttention replaced with Rank-Augmented Linear Attention (RALA). + Adapted from the standard LlamaAttention for demonstration. + **Not** a fully drop-in replacement if you need caching/TP. + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # Same Q, K, V, output projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + + # We will preserve rope usage + self._init_rope() + + # A simple φ-projection for RALA: + # The paper uses φ(x) as a linear transform or identity. We'll do a linear: + self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + def _init_rope(self): + # Standard Llama rope logic + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + RALA forward pass. + This version omits incremental decoding with `past_key_value` for simplicity + (linear attention caching is non-trivial). + """ + bsz, q_len, _ = hidden_states.size() + + # Standard Q, K, V + query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim] + key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim] + value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim] + + # Reshape to [b, n_heads, seq_len, head_dim] + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE (rotary embeddings) just as in standard Llama + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # If you still want to handle the repeated KV for multi-group setups: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Now we apply RALA. + + # 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim] + Q_kappa = kappa(query_states) + K_kappa = kappa(key_states) + + # 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim] + # The paper denotes Q_g = (1/N) Σ_i Q_kappa_i + seq_len_float = float(q_len) # for scaling + Q_g = Q_kappa.mean(dim=2) # [b, n_heads, head_dim] + + # 3) Compute alpha_j for each token j in [0..seq_len-1] + # alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len] + # Dot product over head_dim + # K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim] + # We'll do an einsum or transpose to produce logits [b, n_heads, seq_len] + + # Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len] + # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len] + logits = (Q_g.unsqueeze(2) * K_kappa).sum( + dim=-1 + ) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work + + # 4) Incorporate causal or padding mask if provided. + # In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar. + # For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits. + # Caution: This might not replicate strict causal linear attention. It's a best-effort approach. + if attention_mask is not None: + # Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf + # We want shape [b, n_heads, seq_len], so we can broadcast accordingly: + # e.g., attention_mask: [b, 1, q_len, seq_len] + # We pick the slice that corresponds to q_len vs. kv_len. + # Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`. + # We'll do something like: + if attention_mask.dim() == 4: + # attention_mask: [b, 1, q_len, kv_len] + # if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims + mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len] + # we only want [b, n_heads, seq_len], so we must broadcast over q_len if needed + # but in this snippet, we do a single alpha_j for each j *per head*, + # ignoring per-token Q_i. So there's a mismatch. + # A simpler approach is to apply the mask for the entire sequence if a token j is invalid for ANY i. + # That is approximate. We'll just pick the first row of q_len, or do min across i dimension... + # For demonstration, let's sum or min across i dimension to see if j is valid for ANY i. + # Or we do a "causal" approach: all tokens j>i get masked. But there's no direct i index here in alpha_j. + # We'll just do a rough approach, e.g. mask = min across the q_len dimension: + mask_1d = torch.min(mask_2d, dim=1)[ + 0 + ] # [b, seq_len], picking the worst mask across query positions + # broadcast for n_heads + mask_1d = mask_1d.unsqueeze(1).expand( + -1, self.num_heads, -1 + ) # [b, n_heads, seq_len] + logits = logits + mask_1d + else: + # Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len]. + mask_1d = attention_mask # [b, seq_len] + mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1) + logits = logits + mask_1d + + alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len] + # multiply by seq_len per the formula + alpha = alpha * seq_len_float + + # 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j) + # The paper shows a d×d matrix formed per head. + # K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim] + # For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d + # Then multiply by alpha_j and sum over j. + # We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d] + # alpha: [b, n_heads, seq_len]. + value_states_ = value_states # [b, n_heads, seq_len, head_dim] + outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_) + + # Explanation: + # - 'bnhs' is alpha (batch, n_heads, seq_len) + # - 'bnhsd' is K_kappa (b,n_heads,seq_len, d) + # - 'bnhsf' is V (b,n_heads,seq_len, d) + # We want [b,n_heads,d,f], which is the d×d matrix per head. + # Actually we need an outer product (K_kappa_j^T × V_j). That is [d, d]. + # The call above is not quite correct if we want K_kappa_j^T × V_j as [d,d]. + # Let's do a simpler approach: + # outer_sum = sum_j alpha_j * (K_kappa_j^T outer V_j). + # = "bnhs,bnhsd,bnhsf -> bnhdf" + # means: alpha has shape (b,n,h,s), K_kappa has shape (b,n,h,s,d), V has shape (b,n,h,s,d) + # We want to produce (b,n,h,d,d). + # So the correct einsum string is 'bnhs,bnhsd,bnhsf->bnhdf': + # alpha indexes b,n,h,s + # K_kappa indexes b,n,h,s,d => K_kappa_j + # V indexes b,n,h,s,f => V_j + # The resulting shape is (b,n,h,d,f). Great. + + # 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ] + # Here κ(Q_i) is shape [b,n,h,d], outer_sum is shape [b,n,h,d,d]. + # We'll do a batch matmul: result_attn = Q_kappa_i × outer_sum => [b,n,h,d] + # Then multiply elementwise by φ(X_i). + # But φ(X_i) is a single [b,seq_len,d_model], so we reshape to [b,seq_len,n,h_dim]. + # We'll do per-token i in a loop or broadcast. Let's do it in a single operation with einsum: + + # first, compute φ(X): + # X is the original hidden_states: [b, seq_len, d_model] + X_phi = self.phi(hidden_states) # [b, seq_len, d_model] + X_phi = X_phi.view(bsz, q_len, self.num_heads, self.head_dim) # [b, s, n, d] + X_phi = X_phi.transpose(1, 2) # [b, n, s, d] + + # Now for each i in [0..q_len-1], we do a matrix multiply: + # result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d]. + # We'll do: + result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d] + + # Then elementwise multiply by φ(X_i): + context_layer = X_phi * result_attn # [b,n,s,d] + + # Finally, reorder to [b, s, n, d] -> [b, s, n*d] + context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d] + context_layer = context_layer.view(bsz, q_len, self.hidden_size) + + # One last linear projection: + attn_output = self.o_proj(context_layer) + + # Not returning a standard attn_weights. + # If you want to return alpha as "attention," we can do so: + if output_attentions: + # alpha: [b, n_heads, seq_len], but note it's only the "global" weighting of each key, + # not a (q_len x kv_len) map like standard attention. + attn_weights = alpha + else: + attn_weights = None + + # We omit cache / past_key_value returns to keep it simpler. + return attn_output, attn_weights, None diff --git a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py new file mode 100644 index 000000000..3b06f264a --- /dev/null +++ b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py @@ -0,0 +1,5 @@ +from transformers import LlamaConfig + + +class LlamaRalaConfig(LlamaConfig): + pass diff --git a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py new file mode 100644 index 000000000..0af64fcb8 --- /dev/null +++ b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py @@ -0,0 +1,544 @@ +from typing import List, Optional, Tuple, Union, Unpack + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache, GenerationMixin, LlamaForCausalLM, LlamaModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import ( + KwargsForCausalLM, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, +) + +from .configuration_rala import LlamaRalaConfig + + +def kappa(x: torch.Tensor) -> torch.Tensor: + """ + The paper uses κ(x) = ELU(x) + 1. + x is assumed to be [batch, n_heads, seq_len, head_dim]. + """ + return F.elu(x) + 1 + + +class LlamaRALAAttention(nn.Module): + """ + LlamaAttention replaced with Rank-Augmented Linear Attention (RALA). + Adapted from the standard LlamaAttention for demonstration. + **Not** a fully drop-in replacement if you need caching/TP. + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # Same Q, K, V, output projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + + # We will preserve rope usage + self._init_rope() + + # A simple φ-projection for RALA: + # The paper uses φ(x) as a linear transform or identity. We'll do a linear: + self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def _init_rope(self): + # Standard Llama rope logic + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + RALA forward pass. + This version omits incremental decoding with `past_key_value` for simplicity + (linear attention caching is non-trivial). + """ + bsz, q_len, _ = hidden_states.size() + + # Standard Q, K, V + query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim] + key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim] + value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim] + + # Reshape to [b, n_heads, seq_len, head_dim] + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE (rotary embeddings) just as in standard Llama + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # 4. If we have a past_key_value (Cache object), let it update / append + if past_key_value is not None: + # This is the normal Llama pattern + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # The .update() method returns updated (key_states, value_states) + # and typically updates internal buffers. It may also store `layer_idx` data. + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # If you still want to handle the repeated KV for multi-group setups: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Now we apply RALA. + + # 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim] + Q_kappa = kappa(query_states) + K_kappa = kappa(key_states) + + # 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim] + # The paper denotes Q_g = (1/N) Σ_i Q_kappa_i + seq_len_float = float(q_len) # for scaling + Q_g = Q_kappa.mean(dim=2) # [b, n_heads, head_dim] + + # 3) Compute alpha_j for each token j in [0..seq_len-1] + # alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len] + # Dot product over head_dim + # K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim] + # We'll do an einsum or transpose to produce logits [b, n_heads, seq_len] + + # Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len] + # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len] + logits = (Q_g.unsqueeze(2) * K_kappa).sum( + dim=-1 + ) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work + + # 4) Incorporate causal or padding mask if provided. + # In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar. + # For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits. + # Caution: This might not replicate strict causal linear attention. It's a best-effort approach. + if attention_mask is not None: + # Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf + # We want shape [b, n_heads, seq_len], so we can broadcast accordingly: + # e.g., attention_mask: [b, 1, q_len, seq_len] + # We pick the slice that corresponds to q_len vs. kv_len. + # Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`. + # We'll do something like: + if attention_mask.dim() == 4: + # attention_mask: [b, 1, q_len, kv_len] + # if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims + mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len] + # we only want [b, n_heads, seq_len], so we must broadcast over q_len if needed + # but in this snippet, we do a single alpha_j for each j *per head*, + # ignoring per-token Q_i. So there's a mismatch. + # A simpler approach is to apply the mask for the entire sequence if a token j is invalid for ANY i. + # That is approximate. We'll just pick the first row of q_len, or do min across i dimension... + # For demonstration, let's sum or min across i dimension to see if j is valid for ANY i. + # Or we do a "causal" approach: all tokens j>i get masked. But there's no direct i index here in alpha_j. + # We'll just do a rough approach, e.g. mask = min across the q_len dimension: + mask_1d = torch.min(mask_2d, dim=1)[ + 0 + ] # [b, seq_len], picking the worst mask across query positions + # broadcast for n_heads + mask_1d = mask_1d.unsqueeze(1).expand( + -1, self.num_heads, -1 + ) # [b, n_heads, seq_len] + logits = logits + mask_1d + else: + # Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len]. + mask_1d = attention_mask # [b, seq_len] + mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1) + logits = logits + mask_1d + + alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len] + # multiply by seq_len per the formula + alpha = alpha * seq_len_float + + # 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j) + # The paper shows a d×d matrix formed per head. + # K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim] + # For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d + # Then multiply by alpha_j and sum over j. + # We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d] + # alpha: [b, n_heads, seq_len]. + value_states_ = value_states # [b, n_heads, seq_len, head_dim] + outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_) + + # Explanation: + # - 'bnhs' is alpha (batch, n_heads, seq_len) + # - 'bnhsd' is K_kappa (b,n_heads,seq_len, d) + # - 'bnhsf' is V (b,n_heads,seq_len, d) + # We want [b,n_heads,d,f], which is the d×d matrix per head. + # Actually we need an outer product (K_kappa_j^T × V_j). That is [d, d]. + # The call above is not quite correct if we want K_kappa_j^T × V_j as [d,d]. + # Let's do a simpler approach: + # outer_sum = sum_j alpha_j * (K_kappa_j^T outer V_j). + # = "bnhs,bnhsd,bnhsf -> bnhdf" + # means: alpha has shape (b,n,h,s), K_kappa has shape (b,n,h,s,d), V has shape (b,n,h,s,d) + # We want to produce (b,n,h,d,d). + # So the correct einsum string is 'bnhs,bnhsd,bnhsf->bnhdf': + # alpha indexes b,n,h,s + # K_kappa indexes b,n,h,s,d => K_kappa_j + # V indexes b,n,h,s,f => V_j + # The resulting shape is (b,n,h,d,f). Great. + + # 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ] + # Here κ(Q_i) is shape [b,n,h,d], outer_sum is shape [b,n,h,d,d]. + # We'll do a batch matmul: result_attn = Q_kappa_i × outer_sum => [b,n,h,d] + # Then multiply elementwise by φ(X_i). + # But φ(X_i) is a single [b,seq_len,d_model], so we reshape to [b,seq_len,n,h_dim]. + # We'll do per-token i in a loop or broadcast. Let's do it in a single operation with einsum: + + # first, compute φ(X): + # X is the original hidden_states: [b, seq_len, d_model] + X_phi = self.phi(hidden_states) # [b, seq_len, d_model] + X_phi = X_phi.view(bsz, q_len, self.num_heads, self.head_dim) # [b, s, n, d] + X_phi = X_phi.transpose(1, 2) # [b, n, s, d] + + # Now for each i in [0..q_len-1], we do a matrix multiply: + # result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d]. + # We'll do: + result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d] + + # Then elementwise multiply by φ(X_i): + context_layer = X_phi * result_attn # [b,n,s,d] + + # Finally, reorder to [b, s, n, d] -> [b, s, n*d] + context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d] + context_layer = context_layer.view(bsz, q_len, self.hidden_size) + + # One last linear projection: + attn_output = self.o_proj(context_layer) + + if output_attentions: + # alpha => [b, n_heads, (past_len + q_len)] + attn_weights = alpha + else: + attn_weights = None + + # Return 3-tuple: (attn_output, attn_weights, past_key_value) + return attn_output, attn_weights, past_key_value + + +class LlamaRalaDecoderLayer(nn.Module): + def __init__(self, config: LlamaRalaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaRalaModel(LlamaModel): + config_class = LlamaRalaConfig + + def __init__(self, config: LlamaRalaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + + self.layers = nn.ModuleList( + [ + LlamaRalaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + +class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + config_class = LlamaRalaConfig + _no_split_modules = ["LlamaRalaDecoderLayer"] + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaRalaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py index a9a0bb956..67a525b35 100644 --- a/src/axolotl/integrations/rala/convert.py +++ b/src/axolotl/integrations/rala/convert.py @@ -35,7 +35,6 @@ def copy_attention_weights( nn.init.zeros_(new_attn.phi.weight) else: nn.init.normal_(new_attn.phi) - nn.init.zeros_(new_attn.phi.bias) logger.debug( "Copied positive attention weights from %s to %s", @@ -85,4 +84,13 @@ def convert_to_rala( convert_module(model) logger.info(f"Converted {layer_idx} attention layers to RALA attention") + model.config.architectures = [ + "LlamaRalaForCausalLM", + ] + model.config.model_type = "llama_rala" + model.auto_map = { + "AutoConfig": "llama.configuration_rala.LlamaRalaConfig", + "AutoModel": "llama.modeling_rala.LlamaRalaModel", + "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM", + } return model diff --git a/src/axolotl/integrations/rala/rala_attn.py b/src/axolotl/integrations/rala/rala_attn.py index d73806202..f62289216 100644 --- a/src/axolotl/integrations/rala/rala_attn.py +++ b/src/axolotl/integrations/rala/rala_attn.py @@ -166,7 +166,9 @@ class LlamaRALAAttention(nn.Module): # Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len] # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len] - logits = (Q_g.unsqueeze(2) * K_kappa).sum(dim=-1) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work + logits = (Q_g.unsqueeze(2) * K_kappa).sum( + dim=-1 + ) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work # 4) Incorporate causal or padding mask if provided. # In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar.