moving diff attn code to separate repo

Dan Saunders
2025-01-23 21:33:53 +00:00
parent 28694219a5
commit 7145d52d99
7 changed files with 24 additions and 1305 deletions

View File

@@ -1,208 +0,0 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
def convert_diff_transformer(cfg, cli_args, config_path):
assert not (
cli_args.split_heads and cli_args.zero_init
), "Both `split_heads` and `zero_init` cannot be `True`"
assert not (
cli_args.zero_init and cli_args.mirror_weights
), "Both `zero_init` and `mirror_weights` cannot be `True`"
debug_info = {}
# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)
# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)
# Test original model
if cli_args.debug:
LOG.info("Testing original model...")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)
try:
# Convert attention
LOG.info("Converting to differential attention...")
config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
mirror_weights=cli_args.mirror_weights,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)
# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}
modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]
# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)
if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False
if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False
return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_diff_transformer(cfg, cli_args, config)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,4 +1,5 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
@@ -256,7 +257,12 @@ def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_diff_transformer import do_cli
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
except ImportError as exc:
raise ImportError(
"axolotl-diff-transformer not found, please install it: https://github.com/axolotl-ai-cloud/diff-transformer"
) from exc
do_cli(config=config, **kwargs)

View File

@@ -1,5 +1,19 @@
# Differential Transformer
### Installation
```shell
pip install git+https://github.com/axolotl-ai-cloud/diff-transformer.git
```
Or install in editable mode:
```shell
git clone git@github.com:axolotl-ai-cloud/diff-transformer.git
cd diff-transformer
pip install -e .
```
### Usage
**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.
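For illustration, a sketch of the entries the converter adds to the saved axolotl config (based on the in-tree `convert_diff_transformer` CLI shown above; paths are illustrative and the relocated package may use a different plugin path):
```yaml
# Sketch of the config entries written by the converter (values illustrative)
base_model: ./path/to/output_dir   # the converter's output directory
diff_attention: true
plugins:
  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
```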

View File

@@ -24,7 +24,9 @@ class DifferentialTransformerPlugin(BasePlugin):
to register differential attention custom modeling implementation to `AutoConfig`
and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn
from axolotl_diff_transformer.modeling.modeling_diff_attn import (
register_diff_attn,
)
register_diff_attn()

View File

@@ -1,694 +0,0 @@
"""Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
)
def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (decreasing with `depth`).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttentionBase(nn.Module):
"""
Base class for differential attention implementations.
This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int):
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model.
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__()
self.config = config
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization()
# For logging
self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args:
layer_idx: Index of the current layer.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
if self.config.split_heads:
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self) -> None:
"""
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads:
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
self.q_proj = nn.Linear(
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.config.hidden_size,
bias=self.config.attention_bias,
)
def _init_differential_params(self) -> None:
"""
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Mixing parameter for negative attention component warmup phase.
- Lambda parameters for queries and keys.
- Initial lambda value based on layer index.
- Rotary position embedding layer.
"""
self.diff_attn_mix = 1.0 # Default to full mixing
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self) -> None:
"""
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
def _prepare_attention_inputs(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self,
q1: torch.Tensor,
q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None:
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(
self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
return k1, k2, v
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
return self.diff_attn_mix * lambda_full
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization and
projects back to the model dimension. (The (1 - λ_init) scaling from the paper
is currently disabled; see the note in the implementation below.)
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn)
# NOTE: this may need to be added back in, but doesn't interact well with
# `diff_attn_mix`, and doesn't allow us to preserve the original model output.
# attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Standard attention computation
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions:
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# SDPA-specific attention computation
causal_mask = (
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.config.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
attn1 = F.scaled_dot_product_attention(
q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
attn2 = F.scaled_dot_product_attention(
q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Flash Attention specific processing
q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.config.attention_dropout if self.training else 0.0
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)
attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value

View File

@@ -1,401 +0,0 @@
"""
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential"
def __init__(
self,
split_heads: bool = False,
sublayer_norm: bool = True,
zero_init: bool = False,
mirror_weights: bool = False,
**kwargs,
):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
mirror_weights: Whether to copy the positive attention component weights to
the negative attention component.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
self.zero_init = zero_init
self.mirror_weights = mirror_weights
self.architectures = ["LlamaDifferentialModel"]
self._attn_implementations = {
"eager": "differential_eager",
"sdpa": "differential_sdpa",
"flash_attention_2": "differential_flash_attention_2",
}
class LlamaDifferentialModel(LlamaModel):
"""
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls,
model: LlamaModel | LlamaForCausalLM,
config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel":
"""
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
old_layer.input_layernorm.state_dict()
)
new_layer.post_attention_layernorm.load_state_dict(
old_layer.post_attention_layernorm.state_dict()
)
# Handle attention weights
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
# Get the original projection sizes
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
old_layer.self_attn.k_proj.weight.data
)
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
new_layer.self_attn.lambda_q1.zero_()
new_layer.self_attn.lambda_k1.zero_()
new_layer.self_attn.lambda_q2.zero_()
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
elif config.mirror_weights:
# Mirror weights for second component
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
)
logger.info("Conversion complete")
return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config)
self.model = LlamaDifferentialModel(config)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM":
"""
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model
def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2
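
For context, a minimal loading sketch, assuming the relocated `axolotl_diff_transformer` package exposes `register_diff_attn` at the path used by the plugin change above; once registration has run, converted checkpoints load through the standard `transformers` Auto* pipeline:
```python
# Minimal sketch, not part of this commit; the import path assumes the external
# axolotl_diff_transformer package referenced in the plugin change above.
from transformers import AutoModelForCausalLM

from axolotl_diff_transformer.modeling.modeling_diff_attn import register_diff_attn

register_diff_attn()  # registers "llama-differential" with AutoConfig/AutoModel*
model = AutoModelForCausalLM.from_pretrained("path/to/converted-model")
```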