Adding script for doing conversion; fixes and updates
This commit is contained in:
127
scripts/convert_diff_transformer.py
Normal file
127
scripts/convert_diff_transformer.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Test conversion of transformers model attention to differential attention."""
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model(
    model_name: str, device: str = "cuda"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load a causal LM and its tokenizer from the same checkpoint name.

    Args:
        model_name: Identifier forwarded to both `from_pretrained` calls.
        device: Value passed as `device_map` when loading the model.

    Returns:
        The `(model, tokenizer)` pair; the model is requested in
        `torch.float16`.
    """
    load_options = {"torch_dtype": torch.float16, "device_map": device}
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM:
    """Swap the model's attention layers for differential attention.

    Delegates to `convert_to_diff_attention`. Any exception is reported on
    stdout and then re-raised unchanged.
    """
    try:
        converted = convert_to_diff_attention(model)
    except Exception as exception:
        print(f"Error during model conversion: {exception}")
        raise
    return converted
|
||||||
|
|
||||||
|
|
||||||
|
def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None:
    """Generate a short greedy continuation for each test prompt and print it.

    For every prompt the function tokenizes it, moves the tensors to
    `model.device`, calls `model.generate` under `torch.no_grad()` with the
    KV cache disabled, and prints the wall-clock generation time, the raw
    output ids and the decoded text.

    Raises:
        Exception: Anything raised while tokenizing, generating or decoding
            is printed and then re-raised.
    """
    from time import time

    for prompt in ("The quick brown fox",):
        try:
            # Tokenize and place every tensor on the model's device.
            encoded = tokenizer(prompt, return_tensors="pt")
            inputs = {key: tensor.to(model.device) for key, tensor in encoded.items()}

            # Greedy decoding: one beam, no sampling, no KV cache.
            generation_options = {
                "max_new_tokens": 20,
                "num_beams": 1,
                "do_sample": False,
                "pad_token_id": tokenizer.pad_token_id,
                "use_cache": False,
            }

            started_at = time()
            with torch.no_grad():
                outputs = model.generate(**inputs, **generation_options)
            elapsed = time() - started_at
            print(f"generation time: {elapsed}s")

            # Decode the first (only) sequence of the batch.
            print(outputs)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"\nPrompt: {prompt}")
            print(f"Generated: {generated_text}\n")
        except Exception as exception:
            print(f"Error during inference: {str(exception)}")
            raise
|
||||||
|
|
||||||
|
|
||||||
|
def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None:
    """Announce the destination, then write the model via `save_pretrained`.

    Args:
        model: Model whose `save_pretrained` method is called.
        output_dir: Target directory passed to `save_pretrained`.
    """
    announcement = f"Saving converted model to {output_dir}"
    print(announcement)
    model.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the end-to-end check: load, infer, convert, infer again, save.

    Any exception raised along the way is printed and re-raised.
    """
    # Configuration
    model_name = "HuggingFaceTB/SmolLM2-135M"
    output_dir = "./converted_model"
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    try:
        model, tokenizer = setup_model(model_name, device)

        # Report the dimensions of the unconverted model.
        print("Original model config:")
        config = model.config
        print(f"\t- Hidden size: {config.hidden_size}")
        print(f"\t- Number of attention heads: {config.num_attention_heads}")

        # Baseline generation with the original attention layers.
        test_inference(model, tokenizer)

        model = convert_to_diff_attention(model)
        # NOTE(review): presumably moves the freshly created attention
        # modules onto the device of the existing weights — confirm.
        model.to(model.device)
        print("Model conversion completed")

        # Same generation again, now with differential attention.
        test_inference(model, tokenizer)

        save_converted_model(model, output_dir)
    except Exception as exception:
        print(f"Error during test: {str(exception)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,20 +1,81 @@
|
|||||||
"""Differential attention conversion logic for a huggingface pre-trained model."""
|
"""Differential attention conversion logic for a huggingface pre-trained model."""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
|
||||||
from transformers.models.mistral.modeling_mistral import MistralAttention
|
from transformers.models.mistral.modeling_mistral import MistralAttention
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
|
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
|
||||||
|
|
||||||
from .multihead_diffattn import DifferentialAttention
|
from .multihead_diffattn import (
|
||||||
|
LlamaDifferentialAttention,
|
||||||
|
LlamaDifferentialSdpaAttention,
|
||||||
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
    zero_init: bool = True,
) -> None:
    """
    Populate a differential attention layer from an existing attention layer.

    The old Q and K weights fill the leading rows (Q1 / K1) of the new, wider
    Q and K projections; V and O weights are copied as-is.

    Args:
        old_attn: Source layer providing `q_proj`, `k_proj`, `v_proj`, `o_proj`.
        new_attn: Destination layer, modified in place.
        zero_init: When True, the trailing rows (Q2 / K2) and the four lambda
            parameters are zeroed and `lambda_init` is set to 0.0, which is
            intended to reproduce the original attention output. When False,
            the trailing rows are drawn from a normal distribution
            (mean 0, std 0.1) and the lambda parameters are left untouched.
    """

    def _widened(target: torch.Tensor, source: torch.Tensor, rows: int) -> torch.Tensor:
        # Build the full-size weight off to the side: the first `rows` rows
        # come from `source`, the remainder is zeroed or randomly initialised.
        buffer = torch.empty_like(target)
        buffer[:rows] = source
        remainder = buffer[rows:]
        if zero_init:
            remainder.zero_()
        else:
            nn.init.normal_(remainder, mean=0, std=0.1)
        return buffer

    # Q projection: Q1 occupies the first `hidden_size` rows, Q2 the rest.
    q_target = new_attn.q_proj.weight.data
    q_target.copy_(
        _widened(q_target, old_attn.q_proj.weight.data, new_attn.hidden_size)
    )

    # K projection: K1 occupies as many rows as the old K weight has.
    old_k = old_attn.k_proj.weight.data
    k_target = new_attn.k_proj.weight.data
    k_target.copy_(_widened(k_target, old_k, old_k.size(0)))

    # V and output projections keep their shape and are copied directly.
    for proj_name in ("v_proj", "o_proj"):
        source_weight = getattr(old_attn, proj_name).weight.data
        getattr(new_attn, proj_name).weight.data.copy_(source_weight)

    # Zero out lambda parameters for exact equivalence
    if zero_init:
        for lambda_name in ("lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"):
            nn.init.zeros_(getattr(new_attn, lambda_name))
        new_attn.lambda_init = 0.0

    logger.debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )
|
||||||
|
|
||||||
|
|
||||||
def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
||||||
"""Convert a pre-trained model's attention layers to differential attention"""
|
"""Convert a pre-trained model's attention layers to differential attention"""
|
||||||
attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
|
attention_patterns = (
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaSdpaAttention,
|
||||||
|
MistralAttention,
|
||||||
|
MixtralAttention,
|
||||||
|
)
|
||||||
layer_idx = 0
|
layer_idx = 0
|
||||||
|
|
||||||
# Get model dtype from existing weights
|
# Get model dtype from existing weights
|
||||||
@@ -29,13 +90,22 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
|||||||
layer_type = type(child).__name__
|
layer_type = type(child).__name__
|
||||||
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
|
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
|
||||||
|
|
||||||
|
# Choose appropriate differential attention class
|
||||||
|
if isinstance(child, LlamaSdpaAttention):
|
||||||
|
attention_class = LlamaDifferentialSdpaAttention
|
||||||
|
else:
|
||||||
|
attention_class = LlamaDifferentialAttention
|
||||||
|
|
||||||
# Create new diff attn layer
|
# Create new diff attn layer
|
||||||
new_attention = DifferentialAttention(
|
new_attention = attention_class(
|
||||||
config=module.config if hasattr(module, "config") else model.config,
|
config=module.config if hasattr(module, "config") else model.config,
|
||||||
layer_idx=layer_idx,
|
layer_idx=layer_idx,
|
||||||
dtype=model_dtype,
|
dtype=model_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copy weights from old attention to new attention
|
||||||
|
copy_attention_weights(child, new_attention)
|
||||||
|
|
||||||
# Replace the layer
|
# Replace the layer
|
||||||
setattr(module, name, new_attention)
|
setattr(module, name, new_attention)
|
||||||
layer_idx += 1
|
layer_idx += 1
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ from typing import Any, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import transformers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaRotaryEmbedding,
|
LlamaRotaryEmbedding,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
@@ -34,7 +34,7 @@ def lambda_init_fn(depth):
|
|||||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||||
|
|
||||||
|
|
||||||
class DifferentialAttention(nn.Module):
|
class LlamaDifferentialAttention(nn.Module):
|
||||||
"""Differential Attention implementation as described in the Diff Transformer paper.
|
"""Differential Attention implementation as described in the Diff Transformer paper.
|
||||||
|
|
||||||
This implements a modified attention mechanism that computes the difference between
|
This implements a modified attention mechanism that computes the difference between
|
||||||
@@ -54,7 +54,6 @@ class DifferentialAttention(nn.Module):
|
|||||||
config: Model configuration object containing hidden size, number of heads etc.
|
config: Model configuration object containing hidden size, number of heads etc.
|
||||||
layer_idx: Index of this layer in the transformer stack
|
layer_idx: Index of this layer in the transformer stack
|
||||||
dtype: Data type for the layer parameters
|
dtype: Data type for the layer parameters
|
||||||
is_causal: Whether to use causal (masked) attention
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -62,43 +61,52 @@ class DifferentialAttention(nn.Module):
|
|||||||
config: Any,
|
config: Any,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
is_causal: bool = True,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
# Base model dimensions
|
||||||
self.layer_idx = layer_idx
|
self.attention_dropout = config.attention_dropout
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.base_num_heads = config.num_attention_heads
|
||||||
self.is_causal = is_causal
|
self.base_num_kv_heads = config.num_key_value_heads
|
||||||
# self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads // 2
|
|
||||||
self.num_key_value_heads = getattr(
|
|
||||||
config, "num_key_value_heads", self.num_heads
|
|
||||||
)
|
|
||||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
|
||||||
self.scaling = (self.head_dim) ** -0.5
|
|
||||||
|
|
||||||
# Initialize projections with correct dtype
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
# For Q1 and Q2
|
||||||
self.q_proj = nn.Linear(
|
self.q_proj = nn.Linear(
|
||||||
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
|
self.hidden_size,
|
||||||
|
self.hidden_size * 2,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For K1 and K2
|
||||||
self.k_proj = nn.Linear(
|
self.k_proj = nn.Linear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.hidden_size // self.num_key_value_groups,
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
|
||||||
bias=False,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
self.v_proj = nn.Linear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.hidden_size // self.num_key_value_groups,
|
|
||||||
bias=False,
|
bias=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Single V projection
|
||||||
|
self.v_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output projection
|
||||||
self.o_proj = nn.Linear(
|
self.o_proj = nn.Linear(
|
||||||
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize differential attention parameters
|
# Initialize differential attention parameters
|
||||||
@@ -116,7 +124,6 @@ class DifferentialAttention(nn.Module):
|
|||||||
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
|
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -126,6 +133,7 @@ class DifferentialAttention(nn.Module):
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False, # pylint: disable=unused-argument
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
@@ -134,97 +142,261 @@ class DifferentialAttention(nn.Module):
|
|||||||
Optional[torch.Tensor],
|
Optional[torch.Tensor],
|
||||||
Optional[tuple[torch.Tensor, torch.Tensor]],
|
Optional[tuple[torch.Tensor, torch.Tensor]],
|
||||||
]:
|
]:
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# Project queries, keys and values
|
# Project to Q1,Q2 and K1,K2
|
||||||
q = self.q_proj(hidden_states)
|
qp = self.q_proj(hidden_states)
|
||||||
k = self.k_proj(hidden_states)
|
kp = self.k_proj(hidden_states)
|
||||||
v = self.v_proj(hidden_states)
|
v = self.v_proj(hidden_states)
|
||||||
|
|
||||||
# Reshape for attention
|
# Split into Q1,Q2 and K1,K2
|
||||||
q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
|
q1, q2 = qp.chunk(2, dim=-1)
|
||||||
k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
|
k1, k2 = kp.chunk(2, dim=-1)
|
||||||
1, 2
|
|
||||||
)
|
|
||||||
v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
|
|
||||||
1, 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate or unpack cos, sin for rotary positional embeddings
|
# Reshape Q1,Q2 for attention
|
||||||
|
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Reshape K1,K2 for attention
|
||||||
|
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Reshape V
|
||||||
|
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Apply rotary embeddings
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(
|
position_ids = torch.arange(q_len, device=q1.device)
|
||||||
0, tgt_len, dtype=torch.long, device=q.device
|
cos, sin = self.rotary_emb(q1, position_ids)
|
||||||
)
|
|
||||||
cos, sin = self.rotary_emb(q, position_ids)
|
|
||||||
else:
|
else:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
# Need to adjust cos, sin to match the halved head_dim
|
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||||
cos = cos[..., : self.head_dim]
|
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||||
sin = sin[..., : self.head_dim]
|
|
||||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
k = torch.stack([k1, k2], dim=1)
|
||||||
# Update cache and get back concatenated states
|
|
||||||
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
||||||
|
k1, k2 = k.unbind(dim=1)
|
||||||
|
|
||||||
# Prepare for attention
|
# Repeat KV heads to match Q heads
|
||||||
k = repeat_kv(k, self.num_key_value_groups)
|
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
|
||||||
v = repeat_kv(v, self.num_key_value_groups)
|
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
|
||||||
|
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
|
||||||
|
|
||||||
# Scale query
|
# Calculate attention scores for both parts
|
||||||
q = q * self.scaling
|
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
|
||||||
|
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
|
||||||
|
attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
|
||||||
|
self.head_dim
|
||||||
|
)
|
||||||
|
attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
|
||||||
|
self.head_dim
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate attention scores
|
if attention_mask is not None:
|
||||||
attn_weights = torch.matmul(q, k.transpose(-1, -2))
|
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
|
||||||
|
attn_weights1 = attn_weights1 + causal_mask
|
||||||
|
attn_weights2 = attn_weights2 + causal_mask
|
||||||
|
|
||||||
# Apply causal mask
|
# Apply softmax separately as per paper
|
||||||
if attention_mask is None:
|
attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
|
||||||
attention_mask = torch.triu(
|
attn_weights1
|
||||||
torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
|
)
|
||||||
diagonal=1,
|
attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
|
||||||
).type_as(attn_weights)
|
attn_weights2
|
||||||
attn_weights = torch.nan_to_num(attn_weights)
|
)
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights1 = F.dropout(
|
||||||
|
attn_weights1, p=self.attention_dropout, training=self.training
|
||||||
# Apply softmax
|
)
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
|
attn_weights2 = F.dropout(
|
||||||
attn_weights
|
attn_weights2, p=self.attention_dropout, training=self.training
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate lambda
|
# Calculate lambda
|
||||||
lambda_1 = torch.exp(
|
lambda_1 = torch.exp(
|
||||||
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
||||||
).type_as(q)
|
).type_as(q1)
|
||||||
lambda_2 = torch.exp(
|
lambda_2 = torch.exp(
|
||||||
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
||||||
).type_as(q)
|
).type_as(q1)
|
||||||
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||||
|
|
||||||
# Apply differential attention
|
# Compute differential attention (following paper's formula)
|
||||||
attn_weights = attn_weights.view(
|
attn_weights = attn_weights1 - lambda_full * attn_weights2
|
||||||
bsz, self.num_heads, 2, -1, attn_weights.size(-1)
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
|
|
||||||
|
|
||||||
# Apply attention to values
|
# Apply attention weights to values
|
||||||
attn = torch.matmul(attn_weights, v)
|
attn = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
# Apply sublayer norm
|
# Apply sublayer norm and scaling
|
||||||
attn = self.subln(attn).type_as(attn)
|
# NOTE(Dan): The differential transformers paper applies sublayer normalization at this
|
||||||
|
# point, but this is typically done outside of the attention layer. It would look something
|
||||||
|
# like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
|
||||||
attn = attn * (1 - self.lambda_init)
|
attn = attn * (1 - self.lambda_init)
|
||||||
|
|
||||||
# Reshape and project output
|
# Reshape to output
|
||||||
attn = attn.transpose(1, 2).reshape(
|
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
||||||
bsz, tgt_len, self.num_heads * 2 * self.head_dim
|
|
||||||
)
|
|
||||||
attn = self.o_proj(attn)
|
attn = self.o_proj(attn)
|
||||||
|
|
||||||
# Return in exact format expected by LLaMA
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return attn, attn_weights, past_key_value
|
return attn, attn_weights, past_key_value
|
||||||
return attn, None, past_key_value
|
return attn, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
    """Differential Attention implementation as described in the Diff Transformer paper.
    This implements the same logic as `LlamaDifferentialAttention`, but uses
    `scaled_dot_product_attention` instead of "manually" computing it under the hood.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Separate Q1/Q2 and K1/K2 projections sharing a single V projection
        - Learned lambda parameters that control attention scaling
        - Output scaling by `(1 - lambda_init)`; no sublayer normalization is
          applied inside this forward pass

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Run differential attention through two SDPA calls.

        Returns:
            `(attn_output, None, past_key_value)`. Attention weights are never
            returned from the SDPA path; when `output_attentions` is True the
            call is delegated to the parent (manual) implementation, which
            returns them.
        """
        if output_attentions:
            # NOTE(review): relies on `transformers.logger` being resolvable as
            # a top-level attribute — confirm against the installed version.
            transformers.logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        # Reshape K1,K2 for attention
        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        # Reshape V
        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # The rotary embedding indexes position ids with a leading
                # batch dimension, so the fallback must be 2-D. Prefer the
                # cache positions when given so cached decoding keeps the
                # correct offsets.
                if cache_position is not None:
                    position_ids = cache_position.unsqueeze(0)
                else:
                    position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # K1 and K2 are stacked on a new dim so both go through a single
            # cache update, then split apart again.
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        causal_mask = None
        if attention_mask is not None:
            # Trim the mask's last dim to the (possibly cached) key length.
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]

        # SDPA with memory-efficient backend requires contiguous inputs on CUDA
        if q1.device.type == "cuda" and causal_mask is not None:
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        # Calculate attention using SDPA; with no explicit mask, let SDPA
        # build the causal mask itself (only when more than one query token).
        is_causal = attention_mask is None and q_len > 1
        dropout_p = self.attention_dropout if self.training else 0.0

        attn_output1 = F.scaled_dot_product_attention(
            q1,
            k1,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        attn_output2 = F.scaled_dot_product_attention(
            q2,
            k2,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn_output1 - lambda_full * attn_output2

        # Scale the combined output (no sublayer norm is applied here)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # `output_attentions=True` has already returned through the parent
        # implementation above, so attention weights are never available here.
        return attn, None, past_key_value
|
||||||
|
|||||||
Reference in New Issue
Block a user