axolotl/src/axolotl/integrations/diff_transformer/convert.py
"""Differential attention conversion logic for a huggingface pre-trained model."""

import logging
from typing import Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

from .multihead_diffattn import (
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
    zero_init: bool = True,
) -> None:
    """
    Copy weights from an old attention layer to a new differential attention layer.

    The old weights are copied into the Q1 / K1 halves of the new projections. When
    `zero_init` is True, the Q2 / K2 halves and the lambda parameters are zeroed so
    the new layer is exactly equivalent to the original attention mechanism;
    otherwise, Q2 / K2 are initialized from a small normal distribution.
    """
    # Q projection: rows [0, hidden_size) hold Q1, rows [hidden_size, 2*hidden_size) hold Q2
    new_q = torch.empty_like(new_attn.q_proj.weight.data)
    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
    if zero_init:
        new_q[new_attn.hidden_size :] = 0
    else:
        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
    new_attn.q_proj.weight.data.copy_(new_q)

    # K projection: the first `old_kv_size` rows hold K1, the remainder K2
    old_kv_size = old_attn.k_proj.weight.data.size(0)  # original KV projection size
    new_k = torch.empty_like(new_attn.k_proj.weight.data)
    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
    if zero_init:
        new_k[old_kv_size:] = 0
    else:
        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
    new_attn.k_proj.weight.data.copy_(new_k)

    # V projection (single V) and output projection are copied unchanged
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    # Zero out lambda parameters for exact equivalence
    if zero_init:
        nn.init.zeros_(new_attn.lambda_q1)
        nn.init.zeros_(new_attn.lambda_k1)
        nn.init.zeros_(new_attn.lambda_q2)
        nn.init.zeros_(new_attn.lambda_k2)
        new_attn.lambda_init = 0.0

    logger.debug(
        "Copied attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )
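

# Note on `zero_init` (a sketch, assuming `LlamaDifferentialAttention` follows the
# Differential Transformer formulation, where
#   lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
# and the two attention maps are combined as attn1 - lambda * attn2):
# with lambda_q1 = lambda_k1 = lambda_q2 = lambda_k2 = 0 and lambda_init = 0.0,
# lambda evaluates to 1 - 1 + 0 = 0, so the second ("negative") attention map is
# cancelled and the layer reduces to standard attention over the copied Q1/K1/V weights.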


def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention."""
    attention_patterns = (
        LlamaAttention,
        LlamaSdpaAttention,
        MistralAttention,
        MixtralAttention,
    )
    layer_idx = 0

    # Get model dtype from existing weights
    model_dtype = next(model.parameters()).dtype

    def convert_module(module):
        nonlocal layer_idx

        # Iterate through module children, converting any attention layers to diff attention
        for name, child in module.named_children():
            if isinstance(child, attention_patterns):
                layer_type = type(child).__name__
                logger.info("Converting attention layer %d: %s", layer_idx, layer_type)

                # Choose the matching differential attention class (SDPA vs. eager)
                if isinstance(child, LlamaSdpaAttention):
                    attention_class = LlamaDifferentialSdpaAttention
                else:
                    attention_class = LlamaDifferentialAttention

                # Create the new differential attention layer
                new_attention = attention_class(
                    config=module.config if hasattr(module, "config") else model.config,
                    layer_idx=layer_idx,
                    dtype=model_dtype,
                )

                # Copy weights from the old attention layer, then swap it in place
                copy_attention_weights(child, new_attention)
                setattr(module, name, new_attention)
                layer_idx += 1
            elif len(list(child.children())) > 0:
                convert_module(child)

    convert_module(model)
    logger.info("Converted %d attention layers to differential attention", layer_idx)

    return model
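

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the checkpoint id below is a
    # placeholder and can be swapped for any local path or hub model id).
    from transformers import AutoModelForCausalLM

    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    diff_model = convert_to_diff_attention(base_model)

    # With the default zero_init=True weight copy, the converted model is intended
    # to match the original model's outputs before any further training.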