Add script for testing conversion; fixes and updates

This commit is contained in:
Dan Saunders
2024-12-11 21:35:47 -05:00
parent 13cdffa91f
commit 7be0d7496c
3 changed files with 458 additions and 89 deletions

View File

@@ -0,0 +1,127 @@
"""Test conversion of transformers model attention to differential attention."""
from time import time
from typing import Tuple
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
def setup_model(
model_name: str, device: str = "cuda"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""Load model and tokenizer"""
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer
def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM:
"""Convert model to use differential attention"""
try:
model = convert_to_diff_attention(model)
return model
except Exception as exception:
print(f"Error during model conversion: {exception}")
raise
def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None:
"""Run test inference"""
# Test prompts
test_prompts = [
"The quick brown fox",
]
for prompt in test_prompts:
try:
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
# Generate
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
# temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
# use_cache=True,
)
elapsed = time() - start
print(f"Generation time: {elapsed:.2f}s")
# Decode
print(outputs)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\nPrompt: {prompt}")
print(f"Generated: {generated_text}\n")
except Exception as exception:
print(f"Error during inference: {str(exception)}")
raise
def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None:
"""Save the converted model"""
print(f"Saving converted model to {output_dir}")
model.save_pretrained(output_dir)
def main():
# Configuration
model_name = "HuggingFaceTB/SmolLM2-135M"
# model_name = "openlm-research/open_llama_3b_v2"
output_dir = "./converted_model"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
# Load model and tokenizer
model, tokenizer = setup_model(model_name, device)
# Print original model info
print("Original model config:")
print(f"\t- Hidden size: {model.config.hidden_size}")
print(f"\t- Number of attention heads: {model.config.num_attention_heads}")
# Test the original model
test_inference(model, tokenizer)
# Convert to differential attention
model = convert_to_diff_attention(model)
model.to(device)  # move newly created differential attention layers to the target device
print("Model conversion completed")
# Test the converted model
test_inference(model, tokenizer)
# Save converted model
save_converted_model(model, output_dir)
except Exception as exception:
print(f"Error during test: {str(exception)}")
raise
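# Hedged sanity check (illustrative; not wired into main): with the default
# zero_init=True weight copy in copy_attention_weights, conversion should leave
# the logits unchanged up to floating-point noise. The prompt and tolerance
# below are assumptions for demonstration, not values used elsewhere.
def check_equivalence(model_name: str = "HuggingFaceTB/SmolLM2-135M") -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = setup_model(model_name, device)
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        logits_before = model(**inputs).logits
    model = convert_to_diff_attention(model)
    model.to(device)
    with torch.no_grad():
        logits_after = model(**inputs).logits
    # fp16 numerics: allow a small absolute tolerance
    assert torch.allclose(logits_before, logits_after, atol=1e-2)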
if __name__ == "__main__":
main()

View File

@@ -1,20 +1,81 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from .multihead_diffattn import DifferentialAttention
from .multihead_diffattn import (
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention],
new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
zero_init: bool = True,
) -> None:
"""
Copy weights from old attention layer to new differential attention layer.
With `zero_init=True`, copies the old weights to Q1 and K1 and zeroes Q2, K2, and the
lambda parameters, making the new layer exactly equivalent to the original attention.
"""
# For Q projection (Q1 and Q2)
new_q = torch.empty_like(new_attn.q_proj.weight.data)
new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1
if zero_init:
new_q[new_attn.hidden_size :] = 0
else:
nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
new_attn.q_proj.weight.data.copy_(new_q)
# For K projection (K1 and K2)
old_kv_size = old_attn.k_proj.weight.data.size(0)  # KV projection size of the original model
new_k = torch.empty_like(new_attn.k_proj.weight.data)
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
if zero_init:
new_k[old_kv_size:] = 0
else:
nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
new_attn.k_proj.weight.data.copy_(new_k)
# For V projection (single V)
new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
# Output projection remains the same
new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
# Zero out lambda parameters for exact equivalence
if zero_init:
nn.init.zeros_(new_attn.lambda_q1)
nn.init.zeros_(new_attn.lambda_k1)
nn.init.zeros_(new_attn.lambda_q2)
nn.init.zeros_(new_attn.lambda_k2)
new_attn.lambda_init = 0.0
logger.debug(
"Copied positive attention weights from %s to %s",
type(old_attn).__name__,
type(new_attn).__name__,
)
def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
attention_patterns = (
LlamaAttention,
LlamaSdpaAttention,
MistralAttention,
MixtralAttention,
)
layer_idx = 0
# Get model dtype from existing weights
@@ -29,13 +90,22 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
layer_type = type(child).__name__
logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
# Choose appropriate differential attention class
if isinstance(child, LlamaSdpaAttention):
attention_class = LlamaDifferentialSdpaAttention
else:
attention_class = LlamaDifferentialAttention
# Create new diff attn layer
new_attention = DifferentialAttention(
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,
layer_idx=layer_idx,
dtype=model_dtype,
)
# Copy weights from old attention to new attention
copy_attention_weights(child, new_attention)
# Replace the layer
setattr(module, name, new_attention)
layer_idx += 1
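# Hedged self-check (illustrative; not part of the conversion API): with the
# lambda parameters zero-initialized and lambda_init set to 0.0 above, the
# lambda re-parameterization used by the differential layers collapses to
# exp(0) - exp(0) + 0.0 = 0.0, so the converted layer reduces to standard
# attention over the copied Q1/K1/V/O weights.
if __name__ == "__main__":
    head_dim = 64  # illustrative size; any head_dim gives the same result
    lambda_q1 = lambda_k1 = lambda_q2 = lambda_k2 = torch.zeros(head_dim)
    lambda_full = (
        torch.exp(torch.sum(lambda_q1 * lambda_k1))
        - torch.exp(torch.sum(lambda_q2 * lambda_k2))
        + 0.0
    )
    # With lambda_full == 0, attn_weights1 - lambda_full * attn_weights2 == attn_weights1
    assert lambda_full.item() == 0.0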

View File

@@ -6,9 +6,9 @@ from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
@@ -34,7 +34,7 @@ def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class DifferentialAttention(nn.Module):
class LlamaDifferentialAttention(nn.Module):
"""Differential Attention implementation as described in the Diff Transformer paper.
This implements a modified attention mechanism that computes the difference between
@@ -54,7 +54,6 @@ class DifferentialAttention(nn.Module):
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
is_causal: Whether to use causal (masked) attention
"""
def __init__(
@@ -62,43 +61,52 @@ class DifferentialAttention(nn.Module):
config: Any,
layer_idx: int,
dtype: torch.dtype,
is_causal: bool = True,
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
# Base model dimensions
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.is_causal = is_causal
# self.head_dim = self.hidden_size // self.num_heads
self.head_dim = self.hidden_size // self.num_heads // 2
self.num_key_value_heads = getattr(
config, "num_key_value_heads", self.num_heads
)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.scaling = (self.head_dim) ** -0.5
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
# Initialize projections with correct dtype
self.scaling = self.head_dim**-0.5
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
# For Q1 and Q2
self.q_proj = nn.Linear(
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
self.hidden_size,
self.hidden_size * 2,
bias=False,
dtype=dtype,
)
# For K1 and K2
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.num_key_value_groups,
bias=False,
dtype=dtype,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.num_key_value_groups,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
dtype=dtype,
)
# Single V projection
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
dtype=dtype,
)
# Output projection
self.o_proj = nn.Linear(
self.hidden_size, self.hidden_size, bias=False, dtype=dtype
self.hidden_size,
self.hidden_size,
bias=False,
dtype=dtype,
)
# Initialize differential attention parameters
@@ -116,7 +124,6 @@ class DifferentialAttention(nn.Module):
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
)
self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
def forward(
@@ -126,6 +133,7 @@ class DifferentialAttention(nn.Module):
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
@@ -134,97 +142,261 @@ class DifferentialAttention(nn.Module):
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
bsz, tgt_len, _ = hidden_states.size()
bsz, q_len, _ = hidden_states.size()
# Project queries, keys and values
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Reshape for attention
q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
1, 2
)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Generate or unpack cos, sin for rotary positional embeddings
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(
0, tgt_len, dtype=torch.long, device=q.device
)
cos, sin = self.rotary_emb(q, position_ids)
position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
# Need to adjust cos, sin to match the halved head_dim
cos = cos[..., : self.head_dim]
sin = sin[..., : self.head_dim]
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# Update cache and get back concatenated states
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Prepare for attention
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
# Scale query
q = q * self.scaling
# Calculate attention scores for both parts
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
# Calculate attention scores
attn_weights = torch.matmul(q, k.transpose(-1, -2))
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn_weights1 = attn_weights1 + causal_mask
attn_weights2 = attn_weights2 + causal_mask
# Apply causal mask
if attention_mask is None:
attention_mask = torch.triu(
torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
diagonal=1,
).type_as(attn_weights)
attn_weights = torch.nan_to_num(attn_weights)
attn_weights = attn_weights + attention_mask
# Apply softmax
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
# Apply softmax separately as per paper
attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
attn_weights1
)
attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
attn_weights2
)
attn_weights1 = F.dropout(
attn_weights1, p=self.attention_dropout, training=self.training
)
attn_weights2 = F.dropout(
attn_weights2, p=self.attention_dropout, training=self.training
)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q)
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q)
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Apply differential attention
attn_weights = attn_weights.view(
bsz, self.num_heads, 2, -1, attn_weights.size(-1)
)
attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
# Compute differential attention (following paper's formula)
attn_weights = attn_weights1 - lambda_full * attn_weights2
# Apply attention to values
# Apply attention weights to values
attn = torch.matmul(attn_weights, v)
# Apply sublayer norm
attn = self.subln(attn).type_as(attn)
# Apply sublayer norm and scaling
# NOTE(Dan): The differential transformers paper applies sublayer normalization at this
# point, but this is typically done outside of the attention layer. It would look something
# like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
attn = attn * (1 - self.lambda_init)
# Reshape and project output
attn = attn.transpose(1, 2).reshape(
bsz, tgt_len, self.num_heads * 2 * self.head_dim
)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
# Return in exact format expected by LLaMA
if output_attentions:
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
"""Differential Attention implementation as described in the Diff Transformer paper.
This implements the same logic as `LlamaDifferentialAttention`, but uses
`scaled_dot_product_attention` instead of "manually" computing it under the hood.
This implements a modified attention mechanism that computes the difference between
two attention patterns, scaled by learned lambda parameters. The mechanism helps
reduce noise in the attention weights for irrelevant / less relevant tokens.
Key components:
- Split head dimension for differential computation
- Learned lambda parameters that control attention scaling
- Sublayer normalization on the attention output
See:
- https://arxiv.org/abs/2410.05258
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
Args:
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
"""
def forward(
self,
hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
if output_attentions:
transformers.logging.get_logger(__name__).warning_once(
"LlamaDifferentialSdpaAttention does not support `output_attentions=True` with "
"`torch.nn.functional.scaled_dot_product_attention`. Falling back to the eager "
"`LlamaDifferentialAttention` forward implementation."
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
# Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
causal_mask = None
if attention_mask is not None:
causal_mask = attention_mask
causal_mask = causal_mask[:, :, :, : k1.shape[-2]]
# SDPA with memory-efficient backend requires contiguous inputs on CUDA
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
# Calculate attention using SDPA
is_causal = attention_mask is None and q_len > 1
attn_output1 = F.scaled_dot_product_attention(
q1,
k1,
v,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output2 = F.scaled_dot_product_attention(
q2,
k2,
v,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Combine the attention outputs
attn = attn_output1 - lambda_full * attn_output2
# Apply sublayer norm and scaling
attn = attn * (1 - self.lambda_init)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
if output_attentions:
return (
attn,
None,
past_key_value,
) # Note: can't return attn_weights with SDPA
return attn, None, past_key_value
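# Reference sketch (hedged): the core computation both classes above implement,
# stripped of RoPE, KV caching, GQA head repetition, and reshaping. The helper
# name and shapes are illustrative only, not part of this module's public API.
def _diff_attention_reference(
    q1: torch.Tensor,  # [bsz, num_heads, seq_len, head_dim]
    q2: torch.Tensor,
    k1: torch.Tensor,
    k2: torch.Tensor,
    v: torch.Tensor,
    lambda_full: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    head_dim = q1.size(-1)
    a1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(head_dim)
    a2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(head_dim)
    if mask is not None:
        a1 = a1 + mask
        a2 = a2 + mask
    # Difference of two softmax maps; the lambda-scaled second map cancels
    # attention mass assigned to common / less relevant context.
    return torch.matmul(
        F.softmax(a1, dim=-1) - lambda_full * F.softmax(a2, dim=-1), v
    )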