added modeling code; cleanup + refactor

Dan Saunders
2024-12-23 14:14:51 -05:00
committed by Dan Saunders
parent fcbfa86373
commit 5b90da0be3
3 changed files with 439 additions and 34 deletions

View File

@@ -50,7 +50,7 @@ def copy_attention_weights(
new_attn.q_proj.weight.data.copy_(new_q)
# For K projection (K1 and K2)
old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads
old_kv_size = old_attn.k_proj.weight.data.size(0)
new_k = torch.empty_like(new_attn.k_proj.weight.data)
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
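# With zero_init, the K2 half is zeroed so the converted model initially matches the base model.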
if zero_init:
@@ -99,6 +99,7 @@ def convert_to_diff_attn(
# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
child_class_name = type(child).__name__
if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
# Find matching attention class by name
for orig_class, diff_class in ATTENTION_MAPPING.items():

View File

@@ -7,7 +7,6 @@ from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
@@ -17,7 +16,14 @@ from transformers.models.llama.modeling_llama import (
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
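# flash-attn is optional; LlamaDifferentialFlashAttention2 checks this flag and raises at construction time if it is unavailable.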
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -35,11 +41,12 @@ def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class DifferentialAttentionBase(nn.Module):
class LlamaDifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""
def __init__(self, config: Any, layer_idx: int):
super().__init__()
self.config = config
self._init_config(config, layer_idx)
self._init_projections()
self._init_differential_params()
@@ -59,9 +66,9 @@ class DifferentialAttentionBase(nn.Module):
if config.split_heads:
# Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads // 2
self.head_dim = config.hidden_size // config.num_attention_heads
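# Split-heads mode reuses the base model's single Q/K projections and divides the existing heads between the two attention components.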
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even.
# implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
@@ -110,36 +117,43 @@ class DifferentialAttentionBase(nn.Module):
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(
self.max_position_embeddings, self.head_dim, self.rope_theta
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self, config):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
else:
self.subln = nn.Identity()
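# config.sublayer_norm toggles the extra RMSNorm over the per-head differential output (value_head_dim); nn.Identity() leaves it unchanged.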
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
bsz, q_len, _ = hidden_states.size()
# Project and split
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
# Reshape
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
1, 2
)
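# Resulting shapes: q1/q2/k1/k2 are (bsz, heads_per_component, q_len, head_dim); v is (bsz, heads_per_component, q_len, value_head_dim).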
return q1, q2, k1, k2, v
@@ -148,16 +162,16 @@ class DifferentialAttentionBase(nn.Module):
):
"""Apply rotary embeddings to queries and keys."""
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q1.size(-2), device=q1.device)
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -195,7 +209,7 @@ class DifferentialAttentionBase(nn.Module):
return self.o_proj(attn)
class LlamaDifferentialAttention(DifferentialAttentionBase):
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""Standard implementation of differential attention."""
def forward(
@@ -237,15 +251,16 @@ class LlamaDifferentialAttention(DifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
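# Differential attention: the second attention map, scaled by lambda, is subtracted from the first to cancel common-mode attention noise.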
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
if output_attentions:
return attn, attn1 - lambda_full * attn2, past_key_value
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""SDPA-based implementation of differential attention."""
# pylint: disable=duplicate-code
@@ -262,6 +277,11 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -309,9 +329,18 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention."""
def __init__(self, *args, **kwargs):
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)
super().__init__(*args, **kwargs)
# pylint: disable=duplicate-code
def forward(
self,
@@ -326,6 +355,11 @@ class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
return LlamaDifferentialAttention.forward(
self,
hidden_states,

View File

@@ -0,0 +1,370 @@
"""Modeling for differential transformers."""
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
)
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialAttentionBase,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model."""
def __init__(
self,
split_heads: bool = False,
sublayer_norm: bool = True,
zero_init: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
self.zero_init = zero_init
self.architectures = ["LlamaDifferentialModel"]
self._attn_implementations = {
"eager": "differential_eager",
"sdpa": "differential_sdpa",
"flash_attention_2": "differential_flash_attention_2",
}
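# Standard HF attn_implementation names are remapped to their differential counterparts when LlamaDifferentialModel is constructed.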
class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
"""Base class for differential LLaMA models."""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)):
module.gradient_checkpointing = value
def lambda_init_fn(depth: int) -> float:
"""Initialize lambda parameter based on layer depth."""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
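# Same schedule as lambda_init_fn in diff_attn.py: lambda_init grows from 0.2 at depth 0 toward 0.8 in deeper layers.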
class LlamaDifferentialModel(LlamaDifferentialPreTrainedModel):
"""Differential version of the LLaMA model."""
def __init__(self, config: LlamaDifferentialConfig):
super().__init__(config)
# Map attn implementations to classes
self.attn_implementation_to_class = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
# Get correct attention implementation
attn_implementation = getattr(config, "_attn_implementation", "eager")
if attn_implementation in config._attn_implementations:
attn_implementation = config._attn_implementations[attn_implementation]
self.attention_class = self.attn_implementation_to_class.get(
attn_implementation, LlamaDifferentialAttention
)
# Initialize model components
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, config.pad_token_id
)
self.layers = nn.ModuleList(
[
LlamaDifferentialDecoderLayer(
config=config, layer_idx=i, attention_class=self.attention_class
)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Check if either input_ids or inputs_embeds is provided
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
device = input_ids.device
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
device = inputs_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
# Initialize past_key_values if needed
if past_key_values is None:
past_key_values = tuple([None] * len(self.layers))
# Create attention mask if not provided
if attention_mask is not None:
attention_mask = self._prepare_attention_mask(
attention_mask, (batch_size, seq_length), device
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Initialize lists to store outputs
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_cache = () if use_cache else None
for _, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states += (hidden_states,) # type: ignore
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_cache += (layer_outputs[-1],) # type: ignore
if output_attentions:
all_self_attns += (layer_outputs[1],) # type: ignore
# Add last hidden state
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,) # type: ignore
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _prepare_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
device: torch.device,
) -> torch.Tensor:
"""Prepare attention mask for computing attention."""
# Create causal mask
# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]
combined_attention_mask = None
_, seq_length = input_shape
if self.config.is_decoder:
seq_ids = torch.arange(seq_length, device=device)
causal_mask = (
seq_ids[None, None, :].repeat(1, seq_length, 1)
<= seq_ids[None, :, None]
)
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1:] != (seq_length, seq_length):
causal_mask = causal_mask[:, :seq_length, :seq_length]
# Extend attention mask
combined_attention_mask = (
causal_mask[None, None, :, :] * attention_mask[:, None, None, :]
)
else:
combined_attention_mask = attention_mask[:, None, None, :]
return combined_attention_mask
@classmethod
def from_llama(
cls,
llama_model: LlamaModel,
differential_config: Optional[LlamaDifferentialConfig] = None,
) -> "LlamaDifferentialModel":
"""Convert a standard LLaMA model to use differential attention."""
if differential_config is None:
# pylint: disable=protected-access
differential_config = LlamaDifferentialConfig.from_pretrained(
llama_model.config._name_or_path
)
# Create new model
new_model = cls(differential_config)
# Copy non-attention weights directly
new_model.embed_tokens.load_state_dict(llama_model.embed_tokens.state_dict())
new_model.norm.load_state_dict(llama_model.norm.state_dict())
# Copy layer weights, handling attention layers specially
for new_layer, old_layer in zip(new_model.layers, llama_model.layers):
# Copy self-attention weights with special handling
if differential_config.split_heads:
# Split heads mode
new_layer.self_attn.q_proj.weight.data.copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data.copy_(
old_layer.self_attn.k_proj.weight.data
)
else:
# Double projection mode - copy weights to positive components
new_layer.self_attn.q_proj.weight.data[
: differential_config.hidden_size
].copy_(old_layer.self_attn.q_proj.weight.data)
new_layer.self_attn.k_proj.weight.data[
: differential_config.hidden_size
].copy_(old_layer.self_attn.k_proj.weight.data)
# Zero out relevant parameters for exact equivalence
if differential_config.zero_init:
old_kv_size = old_layer.self_attn.k_proj.weight.data.size(0)
new_layer.self_attn.q_proj.weight.data[
new_layer.self_attn.hidden_size :
] = 0
new_layer.self_attn.k_proj.weight.data[old_kv_size:] = 0
nn.init.zeros_(new_layer.self_attn.lambda_q1)
nn.init.zeros_(new_layer.self_attn.lambda_k1)
nn.init.zeros_(new_layer.self_attn.lambda_q2)
nn.init.zeros_(new_layer.self_attn.lambda_k2)
nn.init.zeros_(new_layer.self_attn.lambda_init)
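# Zeroing the extra Q2/K2 rows and all lambda parameters makes the differential term vanish at initialization,
# so the converted model initially matches the original (the "exact equivalence" referred to above).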
# Copy remaining weights
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
# Copy MLP and layer norm weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
old_layer.input_layernorm.state_dict()
)
new_layer.post_attention_layernorm.load_state_dict(
old_layer.post_attention_layernorm.state_dict()
)
return new_model
class LlamaDifferentialDecoderLayer(nn.Module):
"""Custom decoder layer for diffrential Llama model."""
def __init__(
self, config: LlamaDifferentialConfig, layer_idx: int, attention_class
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = attention_class(config, layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Layer forward pass with differential attention.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,) # type: ignore
if use_cache:
outputs += (present_key_value,) # type: ignore
return outputs # type: ignore
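# Usage sketch (illustrative only; the quoted path is a placeholder, not part of this commit):
#
#     from transformers import LlamaModel
#     base = LlamaModel.from_pretrained("path/to/llama")
#     diff_model = LlamaDifferentialModel.from_llama(base)
#
# from_llama copies the pretrained weights into the differential layers; with zero_init=True
# in the config, the converted model starts out equivalent to the base model.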