From 866d7b304061eb47246e4f48cc6d8a6229537325 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 11 Dec 2024 14:51:53 -0500
Subject: [PATCH] initial diff attn layer / model conversion implementation
 (support for llama arch)

---
 .../integrations/diff_transformer/__init__.py |   0
 .../integrations/diff_transformer/convert.py  |  48 ++++
 .../diff_transformer/multihead_diffattn.py    | 230 ++++++++++++++++++
 3 files changed, 278 insertions(+)
 create mode 100644 src/axolotl/integrations/diff_transformer/__init__.py
 create mode 100644 src/axolotl/integrations/diff_transformer/convert.py
 create mode 100644 src/axolotl/integrations/diff_transformer/multihead_diffattn.py

diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py
new file mode 100644
index 000000000..93a8df073
--- /dev/null
+++ b/src/axolotl/integrations/diff_transformer/convert.py
@@ -0,0 +1,48 @@
+"""Differential attention conversion logic for a huggingface pre-trained model."""
+import logging
+
+from transformers import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.mistral.modeling_mistral import MistralAttention
+from transformers.models.mixtral.modeling_mixtral import MixtralAttention
+
+from .multihead_diffattn import DifferentialAttention
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
+    """Convert a pre-trained model's attention layers to differential attention"""
+    attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
+    layer_idx = 0
+
+    # Get model dtype from existing weights
+    model_dtype = next(model.parameters()).dtype
+
+    def convert_module(module):
+        nonlocal layer_idx
+
+        # Iterate through module children, convert any attn layers to diff attn
+        for name, child in module.named_children():
+            if isinstance(child, attention_patterns):
+                layer_type = type(child).__name__
+                logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
+
+                # Create new diff attn layer
+                new_attention = DifferentialAttention(
+                    config=module.config if hasattr(module, "config") else model.config,
+                    layer_idx=layer_idx,
+                    dtype=model_dtype,
+                )
+
+                # Replace the layer
+                setattr(module, name, new_attention)
+                layer_idx += 1
+            elif len(list(child.children())) > 0:
+                convert_module(child)
+
+    convert_module(model)
+    logger.info(f"Converted {layer_idx} attention layers to differential attention")
+
+    return model
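For reference, a minimal usage sketch of the conversion entry point above (not part of the patch). It assumes the module is importable as `axolotl.integrations.diff_transformer.convert`; the checkpoint path is a placeholder for any Llama-architecture model:

```python
# Sketch only: convert a Llama-family checkpoint's attention layers in place.
# "path/to/llama-model" is a placeholder, not a real checkpoint name.
import torch
from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-model", torch_dtype=torch.bfloat16
)
model = convert_to_diff_attention(model)

# Each LlamaAttention module has been replaced by a DifferentialAttention layer.
print(model.model.layers[0].self_attn)
```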
diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
new file mode 100644
index 000000000..00462475e
--- /dev/null
+++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
@@ -0,0 +1,230 @@
+"""Re-implementation of differential attention."""
+# pylint: disable=invalid-name
+import logging
+import math
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
+from transformers.models.llama.modeling_llama import (
+    LlamaRotaryEmbedding,
+    apply_rotary_pos_emb,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
+    batch_size, n_kv_heads, slen, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, None, :, :]
+        .expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
+        .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
+    )
+
+
+def lambda_init_fn(depth):
+    return 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+
+class DifferentialAttention(nn.Module):
+    """Differential Attention implementation as described in the Diff Transformer paper.
+
+    This implements a modified attention mechanism that computes the difference between
+    two attention patterns, scaled by learned lambda parameters. The mechanism helps
+    reduce noise in the attention weights for irrelevant / less relevant tokens.
+
+    Key components:
+    - Split head dimension for differential computation
+    - Learned lambda parameters that control attention scaling
+    - Sublayer normalization on the attention output
+
+    See:
+    - https://arxiv.org/abs/2410.05258
+    - https://github.com/microsoft/unilm/tree/master/Diff-Transformer
+
+    Args:
+        config: Model configuration object containing hidden size, number of heads, etc.
+        layer_idx: Index of this layer in the transformer stack
+        dtype: Data type for the layer parameters
+        is_causal: Whether to use causal (masked) attention
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        layer_idx: int,
+        dtype: torch.dtype,
+        is_causal: bool = True,
+    ):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.is_causal = is_causal
+        # self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = self.hidden_size // self.num_heads // 2
+        self.num_key_value_heads = getattr(
+            config, "num_key_value_heads", self.num_heads
+        )
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.scaling = (self.head_dim) ** -0.5
+
+        # Initialize projections with correct dtype
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size,
+            self.hidden_size // self.num_key_value_groups,
+            bias=False,
+            dtype=dtype,
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size,
+            self.hidden_size // self.num_key_value_groups,
+            bias=False,
+            dtype=dtype,
+        )
+
+        self.o_proj = nn.Linear(
+            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
+        )
+
+        # Initialize differential attention parameters
+        self.lambda_init = lambda_init_fn(self.layer_idx)
+        self.lambda_q1 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+        )
+        self.lambda_k1 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+        )
+        self.lambda_q2 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+        )
+        self.lambda_k2 = nn.Parameter(
+            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+        )
+
+        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,  # pylint: disable=unused-argument
+    ) -> tuple[
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[tuple[torch.Tensor, torch.Tensor]],
+    ]:
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # Project queries, keys and values
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+
+        # Reshape for attention
+        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
+            1, 2
+        )
+
+        # Generate or unpack cos, sin for rotary positional embeddings
+        if position_embeddings is None:
+            if position_ids is None:
+                position_ids = torch.arange(
+                    0, tgt_len, dtype=torch.long, device=q.device
+                )
+            cos, sin = self.rotary_emb(q, position_ids)
+        else:
+            cos, sin = position_embeddings
+
+        # Need to adjust cos, sin to match the halved head_dim
+        cos = cos[..., : self.head_dim]
+        sin = sin[..., : self.head_dim]
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+
+            # Update cache and get back concatenated states
+            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
+
+        # Prepare for attention
+        k = repeat_kv(k, self.num_key_value_groups)
+        v = repeat_kv(v, self.num_key_value_groups)
+
+        # Scale query
+        q = q * self.scaling
+
+        # Calculate attention scores
+        attn_weights = torch.matmul(q, k.transpose(-1, -2))
+
+        # Apply causal mask
+        if attention_mask is None:
+            attention_mask = torch.triu(
+                torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
+                diagonal=1,
+            ).type_as(attn_weights)
+        attn_weights = torch.nan_to_num(attn_weights)
+        attn_weights = attn_weights + attention_mask
+
+        # Apply softmax
+        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+            attn_weights
+        )
+
+        # Calculate lambda
+        lambda_1 = torch.exp(
+            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
+        ).type_as(q)
+        lambda_2 = torch.exp(
+            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
+        ).type_as(q)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        # Apply differential attention
+        attn_weights = attn_weights.view(
+            bsz, self.num_heads, 2, -1, attn_weights.size(-1)
+        )
+        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
+
+        # Apply attention to values
+        attn = torch.matmul(attn_weights, v)
+
+        # Apply sublayer norm
+        attn = self.subln(attn).type_as(attn)
+        attn = attn * (1 - self.lambda_init)
+
+        # Reshape and project output
+        attn = attn.transpose(1, 2).reshape(
+            bsz, tgt_len, self.num_heads * 2 * self.head_dim
+        )
+        attn = self.o_proj(attn)
+
+        # Return in exact format expected by LLaMA
+        if output_attentions:
+            return attn, attn_weights, past_key_value
+        return attn, None, past_key_value
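For intuition, a minimal, self-contained sketch of the attention-map arithmetic performed in `DifferentialAttention.forward` above, reduced to a single head with a fixed lambda and without RoPE, GQA, or KV caching; all shapes and values are illustrative only:

```python
# Sketch of the differential attention map: two softmax maps per head, second
# subtracted from the first with weight lambda_full.
import torch
import torch.nn.functional as F

bsz, seq_len, head_dim = 2, 8, 16
lambda_full = 0.55  # stands in for lambda_1 - lambda_2 + lambda_init

q = torch.randn(bsz, 2, seq_len, head_dim)  # two query "halves" per head
k = torch.randn(bsz, 2, seq_len, head_dim)  # two key "halves" per head
v = torch.randn(bsz, seq_len, 2 * head_dim)  # values keep the full 2 * head_dim width

mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = q @ k.transpose(-1, -2) * head_dim**-0.5 + mask
attn = F.softmax(scores, dim=-1)

# Differential map: first softmax minus lambda times the second.
diff_attn = attn[:, 0] - lambda_full * attn[:, 1]
out = diff_attn @ v  # (bsz, seq_len, 2 * head_dim)
```

The subtraction of the second softmax map is what cancels common-mode attention noise; in the layer itself, lambda_full is learned per head through lambda_q1/k1 and lambda_q2/k2, and the output is additionally passed through the sublayer RMSNorm and scaled by (1 - lambda_init) before o_proj.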