Adding script for doing conversion; fixes and updates
scripts/convert_diff_transformer.py (new file, 127 lines)
@@ -0,0 +1,127 @@
"""Test conversion of transformers model attention to differential attention."""
from typing import Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention


def setup_model(
    model_name: str, device: str = "cuda"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load model and tokenizer"""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM:
    """Convert model to use differential attention"""
    try:
        model = convert_to_diff_attention(model)
        return model
    except Exception as exception:
        print(f"Error during model conversion: {exception}")
        raise


def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None:
    """Run test inference"""
    # Test prompts
    test_prompts = [
        "The quick brown fox",
    ]

    for prompt in test_prompts:
        try:
            # Tokenize
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Generate
            from time import time

            start = time()
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=20,
                    num_beams=1,
                    do_sample=False,
                    # temperature=0.7,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=False,
                    # use_cache=True,
                )
            elapsed = time() - start
            print(f"generation time: {elapsed}s")

            # Decode
            print(outputs)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"\nPrompt: {prompt}")
            print(f"Generated: {generated_text}\n")

        except Exception as exception:
            print(f"Error during inference: {str(exception)}")
            raise


def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None:
    """Save the converted model"""
    print(f"Saving converted model to {output_dir}")
    model.save_pretrained(output_dir)


def main():
    # Configuration
    model_name = "HuggingFaceTB/SmolLM2-135M"
    # model_name = "openlm-research/open_llama_3b_v2"
    output_dir = "./converted_model"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    try:
        # Load model and tokenizer
        model, tokenizer = setup_model(model_name, device)

        # Print original model info
        print("Original model config:")
        print(f"\t- Hidden size: {model.config.hidden_size}")
        print(f"\t- Number of attention heads: {model.config.num_attention_heads}")

        # Test the original model
        test_inference(model, tokenizer)

        # Convert to differential attention
        model = convert_to_diff_attention(model)
        model.to(model.device)
        print("Model conversion completed")

        # Test the converted model
        test_inference(model, tokenizer)

        # Save converted model
        save_converted_model(model, output_dir)

    except Exception as exception:
        print(f"Error during test: {str(exception)}")
        raise


if __name__ == "__main__":
    main()
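As a quick aside (not part of this commit): a minimal sketch of how the zero-init equivalence that `copy_attention_weights` targets could be checked, assuming the same model and import path used in the script above; the tolerances are illustrative.

# Hypothetical equivalence check: with zero-initialized Q2/K2 and lambda parameters,
# the converted model should reproduce the original logits up to numerical noise.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

model_name = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    baseline = model(**inputs).logits

model = convert_to_diff_attention(model)
with torch.no_grad():
    converted = model(**inputs).logits

torch.testing.assert_close(baseline, converted, rtol=1e-4, atol=1e-4)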
@@ -1,20 +1,81 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

from .multihead_diffattn import DifferentialAttention
from .multihead_diffattn import (
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
    zero_init: bool = True,
) -> None:
    """
    Copy weights from old attention layer to new differential attention layer.
    Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
    to original attention mechanism.
    """
    # For Q projection (Q1 and Q2)
    new_q = torch.empty_like(new_attn.q_proj.weight.data)
    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
    if zero_init:
        new_q[new_attn.hidden_size :] = 0
    else:
        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
    new_attn.q_proj.weight.data.copy_(new_q)

    # For K projection (K1 and K2)
    old_kv_size = old_attn.k_proj.weight.data.size(0)  # Size for 3 heads
    new_k = torch.empty_like(new_attn.k_proj.weight.data)
    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
    if zero_init:
        new_k[old_kv_size:] = 0
    else:
        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
    new_attn.k_proj.weight.data.copy_(new_k)

    # For V projection (single V)
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)

    # Output projection remains the same
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    # Zero out lambda parameters for exact equivalence
    if zero_init:
        nn.init.zeros_(new_attn.lambda_q1)
        nn.init.zeros_(new_attn.lambda_k1)
        nn.init.zeros_(new_attn.lambda_q2)
        nn.init.zeros_(new_attn.lambda_k2)
        new_attn.lambda_init = 0.0

    logger.debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )


def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
    attention_patterns = (
        LlamaAttention,
        LlamaSdpaAttention,
        MistralAttention,
        MixtralAttention,
    )
    layer_idx = 0

    # Get model dtype from existing weights
@@ -29,13 +90,22 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
        layer_type = type(child).__name__
        logger.info(f"Converting attention layer {layer_idx}: {layer_type}")

        # Choose appropriate differential attention class
        if isinstance(child, LlamaSdpaAttention):
            attention_class = LlamaDifferentialSdpaAttention
        else:
            attention_class = LlamaDifferentialAttention

        # Create new diff attn layer
        new_attention = DifferentialAttention(
        new_attention = attention_class(
            config=module.config if hasattr(module, "config") else model.config,
            layer_idx=layer_idx,
            dtype=model_dtype,
        )

        # Copy weights from old attention to new attention
        copy_attention_weights(child, new_attention)

        # Replace the layer
        setattr(module, name, new_attention)
        layer_idx += 1
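A side note on `copy_attention_weights`: writing the old weights into the first `hidden_size` rows of the fused projection and zeroing the rest is exactly what makes the `chunk(2, dim=-1)` split in the forward pass return the original Q1 plus an all-zero Q2. A minimal sketch of that identity (shapes are illustrative, not taken from this commit):

import torch

hidden_size = 8
w_old = torch.randn(hidden_size, hidden_size)      # original q_proj weight
w_new = torch.zeros(2 * hidden_size, hidden_size)  # fused Q1/Q2 projection weight
w_new[:hidden_size] = w_old                        # Q1 rows <- old weights; Q2 rows stay zero

x = torch.randn(4, hidden_size)
q1, q2 = (x @ w_new.T).chunk(2, dim=-1)            # same split used in forward()
assert torch.allclose(q1, x @ w_old.T)
assert torch.equal(q2, torch.zeros_like(q2))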
@@ -6,9 +6,9 @@ from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
@@ -34,7 +34,7 @@ def lambda_init_fn(depth):
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
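# Worked example of the lambda_init schedule above (illustrative values, not from this commit):
#   lambda_init_fn(0)  = 0.8 - 0.6 * exp(-0.0) = 0.200   (first layer)
#   lambda_init_fn(5)  = 0.8 - 0.6 * exp(-1.5) ~ 0.666
#   lambda_init_fn(30) = 0.8 - 0.6 * exp(-9.0) ~ 0.800   (deeper layers approach 0.8)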
class DifferentialAttention(nn.Module):
class LlamaDifferentialAttention(nn.Module):
    """Differential Attention implementation as described in the Diff Transformer paper.

    This implements a modified attention mechanism that computes the difference between
@@ -54,7 +54,6 @@ class DifferentialAttention(nn.Module):
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
        is_causal: Whether to use causal (masked) attention
    """

    def __init__(
@@ -62,43 +61,52 @@ class DifferentialAttention(nn.Module):
        config: Any,
        layer_idx: int,
        dtype: torch.dtype,
        is_causal: bool = True,
    ):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        # Base model dimensions
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.is_causal = is_causal
        # self.head_dim = self.hidden_size // self.num_heads
        self.head_dim = self.hidden_size // self.num_heads // 2
        self.num_key_value_heads = getattr(
            config, "num_key_value_heads", self.num_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.scaling = (self.head_dim) ** -0.5
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Initialize projections with correct dtype
        self.scaling = self.head_dim**-0.5
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        # For Q1 and Q2
        self.q_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
            self.hidden_size,
            self.hidden_size * 2,
            bias=False,
            dtype=dtype,
        )

        # For K1 and K2
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size // self.num_key_value_groups,
            bias=False,
            dtype=dtype,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size // self.num_key_value_groups,
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
            bias=False,
            dtype=dtype,
        )

        # Single V projection
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
            bias=False,
            dtype=dtype,
        )

        # Output projection
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
            self.hidden_size,
            self.hidden_size,
            bias=False,
            dtype=dtype,
        )
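        # Illustrative projection shapes (an assumption based on the SmolLM2-135M config used by
        # the conversion script: hidden_size=576, num_attention_heads=9, num_key_value_heads=3,
        # so head_dim=64):
        #   q_proj: 576 -> 1152  (Q1 and Q2, 9 heads each)
        #   k_proj: 576 ->  384  (K1 and K2, 3 KV heads each)
        #   v_proj: 576 ->  192  (single V, 3 KV heads)
        #   o_proj: 576 ->  576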
        # Initialize differential attention parameters
@@ -116,7 +124,6 @@ class DifferentialAttention(nn.Module):
            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
        )

        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(
@@ -126,6 +133,7 @@ class DifferentialAttention(nn.Module):
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
@@ -134,97 +142,261 @@ class DifferentialAttention(nn.Module):
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        bsz, tgt_len, _ = hidden_states.size()
        bsz, q_len, _ = hidden_states.size()

        # Project queries, keys and values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for attention
        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
            1, 2
        )
        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Generate or unpack cos, sin for rotary positional embeddings
        # Reshape Q1,Q2 for attention
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V
        v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                position_ids = torch.arange(
                    0, tgt_len, dtype=torch.long, device=q.device
                )
            cos, sin = self.rotary_emb(q, position_ids)
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        # Need to adjust cos, sin to match the halved head_dim
        cos = cos[..., : self.head_dim]
        sin = sin[..., : self.head_dim]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

            # Update cache and get back concatenated states
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Prepare for attention
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        # Repeat KV heads to match Q heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        # Scale query
        q = q * self.scaling
        # Calculate attention scores for both parts
        # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
        # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )
        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )

        # Calculate attention scores
        attn_weights = torch.matmul(q, k.transpose(-1, -2))
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn_weights1 = attn_weights1 + causal_mask
            attn_weights2 = attn_weights2 + causal_mask

        # Apply causal mask
        if attention_mask is None:
            attention_mask = torch.triu(
                torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
                diagonal=1,
            ).type_as(attn_weights)
        attn_weights = torch.nan_to_num(attn_weights)
        attn_weights = attn_weights + attention_mask

        # Apply softmax
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        # Apply softmax separately as per paper
        attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
            attn_weights1
        )
        attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
            attn_weights2
        )
        attn_weights1 = F.dropout(
            attn_weights1, p=self.attention_dropout, training=self.training
        )
        attn_weights2 = F.dropout(
            attn_weights2, p=self.attention_dropout, training=self.training
        )
        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q)
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q)
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Apply differential attention
        attn_weights = attn_weights.view(
            bsz, self.num_heads, 2, -1, attn_weights.size(-1)
        )
        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
        # Compute differential attention (following paper's formula)
        attn_weights = attn_weights1 - lambda_full * attn_weights2

        # Apply attention to values
        # Apply attention weights to values
        attn = torch.matmul(attn_weights, v)

        # Apply sublayer norm
        attn = self.subln(attn).type_as(attn)
        # Apply sublayer norm and scaling
        # NOTE(Dan): The differential transformers paper applies sublayer normalization at this
        # point, but this is typically done outside of the attention layer. It would look something
        # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
        attn = attn * (1 - self.lambda_init)

        # Reshape and project output
        attn = attn.transpose(1, 2).reshape(
            bsz, tgt_len, self.num_heads * 2 * self.head_dim
        )
        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # Return in exact format expected by LLaMA
        if output_attentions:
            return attn, attn_weights, past_key_value
        return attn, None, past_key_value
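# For reference, a summary of the eager path implemented above in the paper's notation
# (d is the per-head dimension; lambda_q1, lambda_k1, lambda_q2, lambda_k2 are the learned vectors):
#   A1 = softmax(Q1 K1^T / sqrt(d)),  A2 = softmax(Q2 K2^T / sqrt(d))
#   lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
#   DiffAttn(X) = (A1 - lambda * A2) V, then scaled by (1 - lambda_init) and projected by o_proj.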
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
    """Differential Attention implementation as described in the Diff Transformer paper.
    This implements the same logic as `LlamaDifferentialAttention`, but uses
    `scaled_dot_product_attention` instead of "manually" computing it under the hood.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        if output_attentions:
            transformers.logging.get_logger(__name__).warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        # Reshape K1,K2 for attention
        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        # Reshape V
        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)
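        # Note on caching (a sketch of the intent; the conversion script above runs with
        # use_cache=False, so this path is not exercised there): K1 and K2 are stacked along a
        # new dim before Cache.update() so both halves share one cache entry, and unbind()
        # recovers them after the update; V is cached once since both attention maps reuse it.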
        # Repeat KV heads to match Q heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        causal_mask = None
        if attention_mask is not None:
            causal_mask = attention_mask
            causal_mask = causal_mask[:, :, :, : k1.shape[-2]]

        # SDPA with memory-efficient backend requires contiguous inputs on CUDA
        if q1.device.type == "cuda" and causal_mask is not None:
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        # Calculate attention using SDPA
        is_causal = attention_mask is None and q_len > 1

        attn_output1 = F.scaled_dot_product_attention(
            q1,
            k1,
            v,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output2 = F.scaled_dot_product_attention(
            q2,
            k2,
            v,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn_output1 - lambda_full * attn_output2

        # Apply sublayer norm and scaling
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        if output_attentions:
            return (
                attn,
                None,
                past_key_value,
            )  # Note: can't return attn_weights with SDPA
        return attn, None, past_key_value