initial diff attn layer / model conversion implementation (support for llama arch)
This commit is contained in:
48
src/axolotl/integrations/diff_transformer/convert.py
Normal file
48
src/axolotl/integrations/diff_transformer/convert.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Differential attention conversion logic for a huggingface pre-trained model."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
from transformers.models.mistral.modeling_mistral import MistralAttention
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
|
||||||
|
|
||||||
|
from .multihead_diffattn import DifferentialAttention
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
||||||
|
"""Convert a pre-trained model's attention layers to differential attention"""
|
||||||
|
attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
|
||||||
|
layer_idx = 0
|
||||||
|
|
||||||
|
# Get model dtype from existing weights
|
||||||
|
model_dtype = next(model.parameters()).dtype
|
||||||
|
|
||||||
|
def convert_module(module):
|
||||||
|
nonlocal layer_idx
|
||||||
|
|
||||||
|
# Iterate through module children, convert any attn layers to diff attn
|
||||||
|
for name, child in module.named_children():
|
||||||
|
if isinstance(child, attention_patterns):
|
||||||
|
layer_type = type(child).__name__
|
||||||
|
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
|
||||||
|
|
||||||
|
# Create new diff attn layer
|
||||||
|
new_attention = DifferentialAttention(
|
||||||
|
config=module.config if hasattr(module, "config") else model.config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
dtype=model_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the layer
|
||||||
|
setattr(module, name, new_attention)
|
||||||
|
layer_idx += 1
|
||||||
|
elif len(list(child.children())) > 0:
|
||||||
|
convert_module(child)
|
||||||
|
|
||||||
|
convert_module(model)
|
||||||
|
logger.info(f"Converted {layer_idx} attention layers to differential attention")
|
||||||
|
|
||||||
|
return model
|
||||||
230
src/axolotl/integrations/diff_transformer/multihead_diffattn.py
Normal file
230
src/axolotl/integrations/diff_transformer/multihead_diffattn.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""Re-implemention of differential attention."""
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||||
|
batch_size, n_kv_heads, slen, head_dim = x.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
return (
|
||||||
|
x[:, :, None, :, :]
|
||||||
|
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
|
||||||
|
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def lambda_init_fn(depth):
|
||||||
|
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialAttention(nn.Module):
|
||||||
|
"""Differential Attention implementation as described in the Diff Transformer paper.
|
||||||
|
|
||||||
|
This implements a modified attention mechanism that computes the difference between
|
||||||
|
two attention patterns, scaled by learned lambda parameters. The mechanism helps
|
||||||
|
reduce noise in the attention weights for irrelevant / less relevant tokens.
|
||||||
|
|
||||||
|
Key components:
|
||||||
|
- Split head dimension for differential computation
|
||||||
|
- Learned lambda parameters that control attention scaling
|
||||||
|
- Sublayer normalization on the attention output
|
||||||
|
|
||||||
|
See:
|
||||||
|
- https://arxiv.org/abs/2410.05258
|
||||||
|
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration object containing hidden size, number of heads etc.
|
||||||
|
layer_idx: Index of this layer in the transformer stack
|
||||||
|
dtype: Data type for the layer parameters
|
||||||
|
is_causal: Whether to use causal (masked) attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Any,
|
||||||
|
layer_idx: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
is_causal: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.is_causal = is_causal
|
||||||
|
# self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads // 2
|
||||||
|
self.num_key_value_heads = getattr(
|
||||||
|
config, "num_key_value_heads", self.num_heads
|
||||||
|
)
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.scaling = (self.head_dim) ** -0.5
|
||||||
|
|
||||||
|
# Initialize projections with correct dtype
|
||||||
|
self.q_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size // self.num_key_value_groups,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.v_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size // self.num_key_value_groups,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize differential attention parameters
|
||||||
|
self.lambda_init = lambda_init_fn(self.layer_idx)
|
||||||
|
self.lambda_q1 = nn.Parameter(
|
||||||
|
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
||||||
|
)
|
||||||
|
self.lambda_k1 = nn.Parameter(
|
||||||
|
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
||||||
|
)
|
||||||
|
self.lambda_q2 = nn.Parameter(
|
||||||
|
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
||||||
|
)
|
||||||
|
self.lambda_k2 = nn.Parameter(
|
||||||
|
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
|
||||||
|
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
) -> tuple[
|
||||||
|
torch.Tensor,
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
Optional[tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
]:
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# Project queries, keys and values
|
||||||
|
q = self.q_proj(hidden_states)
|
||||||
|
k = self.k_proj(hidden_states)
|
||||||
|
v = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Reshape for attention
|
||||||
|
q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate or unpack cos, sin for rotary positional embeddings
|
||||||
|
if position_embeddings is None:
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(
|
||||||
|
0, tgt_len, dtype=torch.long, device=q.device
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(q, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
|
# Need to adjust cos, sin to match the halved head_dim
|
||||||
|
cos = cos[..., : self.head_dim]
|
||||||
|
sin = sin[..., : self.head_dim]
|
||||||
|
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
|
||||||
|
# Update cache and get back concatenated states
|
||||||
|
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# Prepare for attention
|
||||||
|
k = repeat_kv(k, self.num_key_value_groups)
|
||||||
|
v = repeat_kv(v, self.num_key_value_groups)
|
||||||
|
|
||||||
|
# Scale query
|
||||||
|
q = q * self.scaling
|
||||||
|
|
||||||
|
# Calculate attention scores
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(-1, -2))
|
||||||
|
|
||||||
|
# Apply causal mask
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.triu(
|
||||||
|
torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
|
||||||
|
diagonal=1,
|
||||||
|
).type_as(attn_weights)
|
||||||
|
attn_weights = torch.nan_to_num(attn_weights)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# Apply softmax
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
|
||||||
|
attn_weights
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate lambda
|
||||||
|
lambda_1 = torch.exp(
|
||||||
|
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
||||||
|
).type_as(q)
|
||||||
|
lambda_2 = torch.exp(
|
||||||
|
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
||||||
|
).type_as(q)
|
||||||
|
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||||
|
|
||||||
|
# Apply differential attention
|
||||||
|
attn_weights = attn_weights.view(
|
||||||
|
bsz, self.num_heads, 2, -1, attn_weights.size(-1)
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
|
||||||
|
|
||||||
|
# Apply attention to values
|
||||||
|
attn = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
|
# Apply sublayer norm
|
||||||
|
attn = self.subln(attn).type_as(attn)
|
||||||
|
attn = attn * (1 - self.lambda_init)
|
||||||
|
|
||||||
|
# Reshape and project output
|
||||||
|
attn = attn.transpose(1, 2).reshape(
|
||||||
|
bsz, tgt_len, self.num_heads * 2 * self.head_dim
|
||||||
|
)
|
||||||
|
attn = self.o_proj(attn)
|
||||||
|
|
||||||
|
# Return in exact format expected by LLaMA
|
||||||
|
if output_attentions:
|
||||||
|
return attn, attn_weights, past_key_value
|
||||||
|
return attn, None, past_key_value
|
||||||
Reference in New Issue
Block a user