moving diff attn code to separate repo
@@ -1,208 +0,0 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""

import logging
import warnings
from pathlib import Path
from time import time
from typing import Union

import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order

LOG = logging.getLogger(__name__)


def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Run test inference and return the generation time and generated text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}

    start = time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            num_beams=1,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False,
        )
    elapsed = time() - start

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    LOG.info("Prompt: %s", prompt)
    LOG.info("Generated: %s", generated_text)
    LOG.info("Generation time: %.2fs", elapsed)

    return elapsed, generated_text


def convert_diff_transformer(cfg, cli_args, config_path):
    assert not (
        cli_args.split_heads and cli_args.zero_init
    ), "Both `split_heads` and `zero_init` cannot be `True`"
    assert not (
        cli_args.zero_init and cli_args.mirror_weights
    ), "Both `zero_init` and `mirror_weights` cannot be `True`"

    debug_info = {}

    # Load model and tokenizer
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    model.to(cfg.device, dtype=cfg.torch_dtype)

    # Log original model info
    LOG.info(
        "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
        model.config.hidden_size,
        model.config.num_attention_heads,
    )

    # Test original model
    if cli_args.debug:
        LOG.info("Testing original model...")
        debug_info["orig_time"], debug_info["orig_text"] = test_inference(
            model, tokenizer
        )

    try:
        # Convert attention
        LOG.info("Converting to differential attention...")

        config = LlamaDifferentialConfig(
            **model.config.__dict__,
            zero_init=cli_args.zero_init,
            sublayer_norm=cli_args.sublayer_norm,
            split_heads=cli_args.split_heads,
            mirror_weights=cli_args.mirror_weights,
        )
        model = LlamaDifferentialForCausalLM.from_llama(model, config)
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise

    # Test converted model
    if cli_args.debug:
        LOG.info("Testing converted model...")
        debug_info["conv_time"], debug_info["conv_text"] = test_inference(
            model, tokenizer
        )

    # Save if requested
    if cfg.output_dir:
        # Save model and tokenizer
        LOG.info("Saving converted model to %s", cfg.output_dir)
        model.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

        # Modify config to reflect new path / differential attention
        output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
        LOG.info("Saving updated config to %s", output_config_path)

        with open(config_path, "r", encoding="utf-8") as file:
            modified_cfg = yaml.safe_load(file) or {}

        modified_cfg["base_model"] = cfg.output_dir
        modified_cfg["diff_attention"] = True
        plugin_class = (
            "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
        )
        if "plugins" in modified_cfg:
            modified_cfg["plugins"].append(plugin_class)
        else:
            modified_cfg["plugins"] = [plugin_class]

        # Write out the updated axolotl config while preserving original ordering / formatting
        dump_yaml_preserved_order(
            data=modified_cfg,
            reference_yaml_path=config_path,
            output_path=output_config_path,
        )
    else:
        LOG.info("Not saving converted model to disk")
        LOG.info("Pass --output-dir path/to/save to save model")

    if cli_args.debug:
        LOG.info(
            Fore.GREEN
            + "Conversion successful!\n"
            + f"Original generation time: {debug_info['orig_time']:.2f}s\n"
            + f"Converted generation time: {debug_info['conv_time']:.2f}s"
            + Fore.RESET
        )

        if debug_info["orig_text"] == debug_info["conv_text"]:
            LOG.info(
                Fore.GREEN
                + "Generations match!\n"
                + "Model generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + Fore.RESET
            )
            debug_info["generations_match"] = True
        else:
            message = (
                "Generations do not match.\n"
                + "Original generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + "Converted generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['conv_text']}\n"
                + "*" * 50
                + "\n"
            )
            debug_info["generations_match"] = False

            if cli_args.zero_init and not cli_args.sublayer_norm:
                LOG.info(Fore.RED + message + Fore.RESET)
                debug_info["match_expected"] = True
            else:
                LOG.info(
                    Fore.YELLOW
                    + message
                    + "However, this is expected since --zero-init"
                    + " and --no-sublayer-norm were not passed."
                    + Fore.RESET
                )
                debug_info["match_expected"] = False

    return model, debug_info


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    print_axolotl_text_art()

    cfg = load_cfg(config, **kwargs)
    parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    convert_diff_transformer(cfg, cli_args, config)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)
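One property worth keeping in mind from the converter above: with `--zero-init` and `--no-sublayer-norm`, conversion is intended to be a functional no-op, which is exactly what the debug path checks. A minimal sketch of that invariant, assuming the relocated package layout introduced in this commit and an illustrative checkpoint path:

```python
# Sketch only: zero-init + no sublayer norm should preserve greedy generations.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from axolotl_diff_transformer.modeling.modeling_diff_attn import (  # assumed path
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
)

base = "path/to/llama-checkpoint"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(base)
model = LlamaForCausalLM.from_pretrained(base)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
before = model.generate(**inputs, max_new_tokens=20, do_sample=False)

config = LlamaDifferentialConfig(
    **model.config.__dict__, zero_init=True, sublayer_norm=False
)
diff_model = LlamaDifferentialForCausalLM.from_llama(model, config)
after = diff_model.generate(**inputs, max_new_tokens=20, do_sample=False)

assert torch.equal(before, after)  # mirrors the CLI's "Generations match!" check
```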
@@ -1,4 +1,5 @@
"""CLI definition for various axolotl commands."""

# pylint: disable=redefined-outer-name
import subprocess  # nosec B404
from typing import Optional
@@ -256,7 +257,12 @@ def convert_diff_transformer(config: str, **kwargs):
    """Convert model attention layers to differential attention layers."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    from axolotl.cli.integrations.convert_diff_transformer import do_cli
+    try:
+        from axolotl_diff_transformer.convert_diff_transformer import do_cli
+    except ImportError as exc:
+        raise ImportError(
+            "axolotl-diff-transformer not found, please install it: https://github.com/axolotl-ai-cloud/diff-transformer"
+        ) from exc

    do_cli(config=config, **kwargs)
@@ -1,5 +1,19 @@
# Differential Transformer

### Installation

```shell
pip install git+https://github.com/axolotl-ai-cloud/diff-transformer.git
```

Editable:

```shell
git clone git@github.com:axolotl-ai-cloud/diff-transformer.git
cd diff-transformer
pip install -e .
```

### Usage

**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.
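After installation, conversion is driven by `axolotl convert-diff-transformer your_config.yml`. For orientation, the (now relocated) converter script earlier in this commit rewrites the saved config roughly along these lines; values here are illustrative, with `base_model` pointing at whatever `--output-dir` you chose:

```yaml
# Sketch of <output_dir>/axolotl_config.yml after conversion (illustrative values)
base_model: ./converted-model   # set to the conversion --output-dir
diff_attention: true
plugins:
  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
```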
@@ -24,7 +24,9 @@ class DifferentialTransformerPlugin(BasePlugin):
        to register differential attention custom modeling implementation to `AutoConfig`
        and `AutoModel`.
        """
-        from .modeling_diff_attn import register_diff_attn
+        from axolotl_diff_transformer.modeling.modeling_diff_attn import (
+            register_diff_attn,
+        )

        register_diff_attn()
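The relocated import above is the whole integration surface: the plugin only needs the external package at load time. A minimal sketch of what this hook does, using the module path introduced in this commit (requires the `pip install` from the README above):

```python
# Sketch: the registration step the plugin performs on load.
from axolotl_diff_transformer.modeling.modeling_diff_attn import register_diff_attn

# Registers the differential config/model classes with transformers' Auto*
# classes (see register_diff_attn() at the bottom of this commit).
register_diff_attn()
```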
@@ -1,694 +0,0 @@
"""Re-implementation of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name

import logging
import math
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

try:
    from flash_attn.flash_attn_interface import flash_attn_func

    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeats key/value heads to match the number of query heads in multi-head attention.

    Args:
        x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
        n_rep: Number of times to repeat each head.

    Returns:
        Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
        seq_len, head_dim)`.
        If `n_rep` is 1, returns the input tensor unchanged.
    """
    batch_size, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
        .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
    )

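# Example (shapes only, illustrative): for GQA with 8 KV heads serving 32 query
# heads, n_rep = 4, so a (2, 8, 128, 64) key/value tensor becomes (2, 32, 128, 64).
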
def lambda_init_fn(depth: int) -> float:
    """
    Lambda mixing parameter init function from the "Differential Transformer" paper.

    Args:
        depth: Index of layer to init lambda parameter.

    Returns:
        Lambda initialization value (increasing with `depth` toward 0.8).
    """
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

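# Example values: depth 0 -> 0.2, depth 1 -> ~0.356, depth 12 -> ~0.784,
# approaching 0.8 in very deep layers.
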
class LlamaDifferentialAttentionBase(nn.Module):
    """
    Base class for differential attention implementations.

    This class implements the core differential attention mechanism used in Llama models.
    It supports both split heads and double projection modes for attention computation.
    """

    def __init__(self, config: Any, layer_idx: int):
        """
        Initializes the differential attention module.

        Args:
            config: Model configuration object containing hyperparameters, including:
                - hidden_size: The size of hidden states.
                - num_attention_heads: Number of attention heads.
                - num_key_value_heads: Number of key/value heads.
                - attention_bias: Whether to use bias in attention projections.
                - split_heads: Whether to use split heads mode.
                - rms_norm_eps: Epsilon for RMS normalization.
            layer_idx: The index of this layer in the model.

        Note:
            The initialization process consists of four steps:
                1. Configuration initialization (`_init_config`)
                2. Projection layers initialization (`_init_projections`)
                3. Differential parameters initialization (`_init_differential_params`)
                4. Normalization layers initialization (`_init_normalization`)
        """
        super().__init__()

        self.config = config
        self._init_config(layer_idx)
        self._init_projections()
        self._init_differential_params()
        self._init_normalization()

        # For logging
        self.attn1 = None
        self.attn2 = None
        self.lambda_full = None

    def _init_config(self, layer_idx: int) -> None:
        """
        Initializes configuration parameters for the attention layer. Sets up various
        dimension sizes and head counts based on the provided config. Handles both
        split heads and double projection modes.

        In split heads mode, the number of heads is divided by 2 (rounding down), which
        differs from the original implementation that required an even number.

        Args:
            layer_idx: Index of the current layer.
        """
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        self.base_num_heads = self.config.num_attention_heads
        self.base_num_kv_heads = self.config.num_key_value_heads
        self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
        self.layer_idx = layer_idx

        if self.config.split_heads:
            self.heads_per_component = self.base_num_heads // 2
            self.kv_heads_per_component = self.base_num_kv_heads // 2
            self.value_head_dim = 2 * self.head_dim
        else:
            self.heads_per_component = self.base_num_heads
            self.kv_heads_per_component = self.base_num_kv_heads
            self.value_head_dim = self.head_dim

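    # Illustrative sizing (not from the source): with hidden_size=4096 and 32 heads
    # (head_dim=128), split-heads mode yields 16 query heads per component with
    # value_head_dim=256, while double-projection mode keeps all 32 heads per
    # component and instead doubles the Q/K projection widths below.
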
    def _init_projections(self) -> None:
        """
        Initializes the query, key, value, and output projection layers.

        Creates linear transformations for Q, K, V projections with dimensions
        depending on whether split heads or double projection mode is used.
        The output projection combines the attention heads back to model dimension.
        """
        if self.config.split_heads:
            q_out_dim = self.config.hidden_size
            k_out_dim = self.head_dim * self.base_num_kv_heads
        else:
            q_out_dim = self.config.hidden_size * 2
            k_out_dim = self.head_dim * self.base_num_kv_heads * 2

        self.q_proj = nn.Linear(
            self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
        )
        self.v_proj = nn.Linear(
            self.config.hidden_size,
            self.head_dim * self.base_num_kv_heads,
            bias=self.config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.base_num_heads * self.head_dim,
            self.config.hidden_size,
            bias=self.config.attention_bias,
        )

    def _init_differential_params(self) -> None:
        """
        Initializes parameters specific to differential attention.

        Creates learnable parameters for the differential attention mechanism:
            - Mixing parameter for negative attention component warmup phase.
            - Lambda parameters for queries and keys.
            - Initial lambda value based on layer index.
            - Rotary position embedding layer.
        """
        self.diff_attn_mix = 1.0  # Default to full mixing

        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def _init_normalization(self) -> None:
        """
        Initializes normalization layers for the attention mechanism.

        Sets up either RMS normalization or identity transformation based on config.
        The normalization is applied to the sublayer output if enabled.
        """
        sublayer_norm = getattr(self.config, "sublayer_norm", True)
        if sublayer_norm:
            self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
        else:
            self.subln = nn.Identity()

    def _prepare_attention_inputs(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepares input tensors for attention computation.

        Projects input hidden states to query, key, and value spaces, then reshapes
        them for multi-head attention processing.

        Args:
            hidden_states: Input tensor of shape `(batch_size, seq_len, hidden_size)`.

        Returns:
            tuple: Tuple containing:
                - q1: Positive attention query component
                - q2: Negative attention query component
                - k1: Positive attention key component
                - k2: Negative attention key component
                - v: Value tensor
        """
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        q1, q2 = q.chunk(2, dim=-1)
        k1, k2 = k.chunk(2, dim=-1)

        q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
            1, 2
        )
        q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
            1, 2
        )
        k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
            1, 2
        )
        k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
            1, 2
        )
        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)

        return q1, q2, k1, k2, v

    def _apply_rotary_embeddings(
        self,
        q1: torch.Tensor,
        q2: torch.Tensor,
        k1: torch.Tensor,
        k2: torch.Tensor,
        position_ids: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Applies rotary positional embeddings to queries and keys.

        Args:
            q1: Positive attention query component.
            q2: Negative attention query component.
            k1: Positive attention key component.
            k2: Negative attention key component.
            position_ids: Token position indices.
            position_embeddings: Pre-computed rotary embeddings (cos, sin).

        Returns:
            tuple: Tuple containing:
                - q1: Positive attention query with positional encoding.
                - q2: Negative attention query with positional encoding.
                - k1: Positive attention key with positional encoding.
                - k2: Negative attention key with positional encoding.
                - cos: Cosine part of rotary embeddings.
                - sin: Sine part of rotary embeddings.
        """
        if position_embeddings is None:
            LOG.warning(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        return q1, q2, k1, k2, cos, sin

    def _handle_cache(
        self,
        k1: torch.Tensor,
        k2: torch.Tensor,
        v: torch.Tensor,
        past_key_value: Cache | None,
        cache_kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Handles key-value caching for autoregressive generation and the repetition of
        key-value heads to match the number of query heads.

        Args:
            k1: Positive attention key component.
            k2: Negative attention key component.
            v: Value tensor.
            past_key_value: Cache object for storing previous key-value pairs.
            cache_kwargs: Additional arguments for cache handling.

        Returns:
            tuple: Tuple containing:
                - k1: Processed positive attention key component.
                - k2: Processed negative attention key component.
                - v: Processed value tensor.
        """
        if past_key_value is not None:
            # Stack k1/k2 along a new dim so one cache entry holds both components
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        k1 = repeat_kv(k1, self.num_key_value_groups)
        k2 = repeat_kv(k2, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        if self.config.split_heads:
            v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)

        return k1, k2, v

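    # For reference, the reparameterization computed below (per the paper):
    #   lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
    # where "." is a dot product over head_dim.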
    def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
        """
        Computes lambda values for differential attention.

        The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
        from the learned parameters. `diff_attn_mix` is multiplied through the result
        for negative attention component warmup phase (if applicable).

        Args:
            q1: Positive attention query component, used for type casting.

        Returns:
            Computed lambda value for differential attention.
        """
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        return self.diff_attn_mix * lambda_full

    def _process_attention_output(
        self, attn: torch.Tensor, bsz: int, q_len: int
    ) -> torch.Tensor:
        """
        Processes and projects the attention output. Applies sublayer normalization,
        scales by (1 - λ_init), and projects back to model dimension.

        Args:
            attn: Raw attention output.
            bsz: Batch size.
            q_len: Query sequence length.

        Returns:
            Processed attention output of shape (batch_size, seq_len, hidden_size)
        """
        attn = self.subln(attn)
        # NOTE: this may need to be added back in, but doesn't interact well with
        # `diff_attn_mix`, and doesn't allow us to preserve the original model output.
        # attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)

        return self.o_proj(attn)


class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
    """
    Standard implementation of differential attention.

    This class implements the standard differential attention mechanism using
    explicit matrix multiplications for the attention computation.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Computes differential attention using standard matrix multiplication operations.

        Args:
            hidden_states: Input tensor containing sequence to attend to.
            attention_mask: Mask to avoid attention on padding tokens.
            position_ids: Indices of positions for positional embeddings.
            past_key_value: Cached key and value tensors for autoregressive decoding.
            output_attentions: Whether to return attention weights.
            use_cache: Whether to use cached key/value states.
            cache_position: Position indices for cached states.
            position_embeddings: Pre-computed positional embeddings.
            **kwargs: Additional arguments passed to the forward call.

        Returns:
            tuple containing:
                - Output tensor after attention computation.
                - Attention weights if output_attentions is True, else None.
                - Updated key-value cache if use_cache is True, else None.
        """
        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )

        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)

        # Standard attention computation
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn1 = attn1 + causal_mask
            attn2 = attn2 + causal_mask

        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)

        dropout_p = self.config.attention_dropout if self.training else 0.0
        attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
        attn2 = F.dropout(attn2, p=dropout_p, training=self.training)

        lambda_full = self._compute_lambda(q1)
        attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
        attn = self._process_attention_output(attn, bsz, q_len)

        # Save for logging
        self.attn1 = attn1
        self.attn2 = attn2
        self.lambda_full = lambda_full

        if output_attentions:
            attn_weights = attn1 - lambda_full * attn2
            attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
            return attn, attn_weights, past_key_value
        return attn, None, past_key_value


class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
    """
    SDPA-based implementation of differential attention.

    This class implements differential attention using PyTorch's scaled_dot_product_attention
    for improved performance on supported hardware.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Computes differential attention using PyTorch's scaled dot product attention.

        Args:
            hidden_states: Input tensor containing sequence to attend to.
            attention_mask: Mask to avoid attention on padding tokens.
            position_ids: Indices of positions for positional embeddings.
            past_key_value: Cached key and value tensors for autoregressive decoding.
            output_attentions: Whether to return attention weights.
            use_cache: Whether to use cached key/value states.
            cache_position: Position indices for cached states.
            position_embeddings: Pre-computed positional embeddings.
            **kwargs: Additional arguments passed to the forward call.

        Returns:
            tuple containing:
                - Output tensor after attention computation.
                - None for attention weights (SDPA doesn't support output_attentions).
                - Updated key-value cache if use_cache is True, else None.
        """
        if output_attentions:
            LOG.warning(
                "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
                + "`torch.nn.functional.scaled_dot_product_attention` does not support "
                + "`output_attentions=True`. Falling back to the eager attention implementation."
            )

            # pylint: disable=duplicate-code
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )

        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)

        # SDPA-specific attention computation
        causal_mask = (
            None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
        )
        is_causal = attention_mask is None and q_len > 1
        dropout_p = self.config.attention_dropout if self.training else 0.0

        if q1.device.type == "cuda" and causal_mask is not None:
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        attn1 = F.scaled_dot_product_attention(
            q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
        )
        attn2 = F.scaled_dot_product_attention(
            q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
        )

        lambda_full = self._compute_lambda(q1)
        attn = attn1 - lambda_full * attn2
        attn = self._process_attention_output(attn, bsz, q_len)

        # Save for logging
        self.attn1 = attn1
        self.attn2 = attn2
        self.lambda_full = lambda_full

        return attn, None, past_key_value


class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
    """
    Flash Attention 2-based implementation of differential attention.

    This class implements differential attention using Flash Attention 2 for maximum
    performance on supported hardware.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes the Flash Attention 2 differential attention module.

        Args:
            *args: Positional arguments passed to parent class.
            **kwargs: Keyword arguments passed to parent class.

        Raises:
            ImportError: If flash-attn library is not installed.
        """
        if not FLASH_ATTENTION_AVAILABLE:
            raise ImportError(
                "LlamaDifferentialFlashAttention2 requires flash-attn library. "
                "Please install with `pip install flash-attn --no-build-isolation`"
            )

        super().__init__(*args, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Computes differential attention using Flash Attention 2.

        Args:
            hidden_states: Input tensor containing sequence to attend to.
            attention_mask: Mask to avoid attention on padding tokens.
            position_ids: Indices of positions for positional embeddings.
            past_key_value: Cached key and value tensors for autoregressive decoding.
            output_attentions: Whether to return attention weights.
            use_cache: Whether to use cached key/value states.
            cache_position: Position indices for cached states.
            position_embeddings: Pre-computed positional embeddings.
            **kwargs: Additional arguments passed to the forward call.

        Returns:
            tuple containing:
                - Output tensor after attention computation.
                - None for attention weights (Flash Attention doesn't support output_attentions).
                - Updated key-value cache if use_cache is True, else None.
        """
        if output_attentions:
            LOG.warning(
                "LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
                + "flash attention does not support `output_attentions=True`. Falling back "
                + "to the eager attention implementation."
            )

            # pylint: disable=duplicate-code
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )

        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)

        # Flash Attention specific processing
        q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
        k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
        v = v.transpose(1, 2)

        dropout_p = self.config.attention_dropout if self.training else 0.0

        if self.config.split_heads:
            v1, v2 = v.chunk(2, dim=-1)
            attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
            attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
            attn1 = torch.cat([attn11, attn12], dim=-1)

            attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
            attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
            attn2 = torch.cat([attn21, attn22], dim=-1)
        else:
            attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
            attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)

        attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)

        lambda_full = self._compute_lambda(q1)
        attn = attn1 - lambda_full * attn2
        attn = self._process_attention_output(attn, bsz, q_len)

        # Save for logging
        self.attn1 = attn1
        self.attn2 = attn2
        self.lambda_full = lambda_full

        return attn, None, past_key_value
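To make the mechanism above easy to sanity-check in isolation, here is a minimal single-head eager sketch of the core computation (illustrative only: fixed example shapes, no causal mask, RoPE, caching, or GQA):

```python
import torch
import torch.nn.functional as F


def diff_attn_single_head(q1, q2, k1, k2, v, lam):
    """softmax(q1 k1^T / sqrt(d)) v - lam * softmax(q2 k2^T / sqrt(d)) v"""
    d = q1.size(-1)
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d**0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d**0.5, dim=-1)
    return a1 @ v - lam * a2 @ v


seq, dim = 8, 64
q1, q2, k1, k2, v = (torch.randn(seq, dim) for _ in range(5))
out = diff_attn_single_head(q1, q2, k1, k2, v, lam=0.2)  # 0.2 = lambda_init at depth 0
print(out.shape)  # torch.Size([8, 64])
```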
@@ -1,401 +0,0 @@
"""
Modeling for differential transformers.

This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""

import logging

import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel

from .diff_attn import (
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)

logger = logging.getLogger(__name__)


class LlamaDifferentialConfig(LlamaConfig):
    """
    Configuration class for Differential LLaMA model.

    Extends the base LLaMA configuration with additional parameters for differential
    attention mechanisms.
    """

    model_type = "llama-differential"

    def __init__(
        self,
        split_heads: bool = False,
        sublayer_norm: bool = True,
        zero_init: bool = False,
        mirror_weights: bool = False,
        **kwargs,
    ):
        """
        Initialize differential LLaMA configuration.

        Args:
            split_heads: Whether to use split heads mode for attention computation.
            sublayer_norm: Whether to apply normalization to sublayers.
            zero_init: Whether to initialize new weights to zero.
            mirror_weights: Whether to copy the positive attention component weights to
                the negative attention component.
            **kwargs: Additional arguments passed to LlamaConfig.
        """
        super().__init__(**kwargs)
        self.split_heads = split_heads
        self.sublayer_norm = sublayer_norm
        self.zero_init = zero_init
        self.mirror_weights = mirror_weights
        self.architectures = ["LlamaDifferentialModel"]
        self._attn_implementations = {
            "eager": "differential_eager",
            "sdpa": "differential_sdpa",
            "flash_attention_2": "differential_flash_attention_2",
        }


class LlamaDifferentialModel(LlamaModel):
    """
    LlamaModel with differential attention.

    This class extends the base LLaMA model by replacing standard attention with
    differential attention mechanisms.
    """

    config_class = LlamaDifferentialConfig
    base_model_prefix = "llama_differential"

    def __init__(self, config: LlamaDifferentialConfig):
        """
        Initialize a differential LLaMA model.

        Args:
            config: Configuration object for the model.

        Raises:
            ValueError: If specified attention implementation is not supported.
        """
        super().__init__(config)

        # Handle attention implementation
        attn_impl = config._attn_implementation or "eager"
        if attn_impl in config._attn_implementations:
            attn_impl = config._attn_implementations[attn_impl]

        # Validate attention implementation
        valid_impls = [
            None,
            "differential_eager",
            "differential_sdpa",
            "differential_flash_attention_2",
        ]
        if attn_impl not in valid_impls:
            raise ValueError(f"Invalid attention implementation: {attn_impl}")

        # Replace standard attention with differential attention in each layer
        attn_classes = {
            "differential_eager": LlamaDifferentialAttention,
            "differential_sdpa": LlamaDifferentialSdpaAttention,
            "differential_flash_attention_2": LlamaDifferentialFlashAttention2,
        }
        attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)

        for idx, layer in enumerate(self.layers):
            layer.self_attn = attn_class(config, idx)

    @classmethod
    # pylint: disable=protected-access
    def _autoset_attn_implementation(
        cls,
        config: LlamaDifferentialConfig,
        **kwargs,  # pylint: disable=unused-argument
    ) -> LlamaDifferentialConfig:
        """
        Automatically set the attention implementation based on config.

        Args:
            config: Model configuration object.
            **kwargs: Additional arguments (unused).

        Returns:
            Updated configuration object.

        Raises:
            ValueError: If specified attention implementation is not supported.
        """
        config._attn_implementation_autoset = True
        attn_implementation = getattr(config, "_attn_implementation", None)

        # Map standard types to differential types if mapping exists
        if attn_implementation in config._attn_implementations:
            config._attn_implementation = config._attn_implementations[
                attn_implementation
            ]
            return config

        # If no mapping, validate it's a valid differential type
        valid_impls = [
            None,
            "differential_eager",
            "differential_sdpa",
            "differential_flash_attention_2",
        ]
        if attn_implementation not in valid_impls:
            message = (
                f"Specified `attn_implementation={attn_implementation}` is not supported. "
                f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
            )
            raise ValueError(message)

        return config

    @classmethod
    def from_llama(
        cls,
        model: LlamaModel | LlamaForCausalLM,
        config: LlamaDifferentialConfig | None = None,
    ) -> "LlamaDifferentialModel":
        """
        Convert a `LlamaModel` to use differential attention.

        Args:
            model: Base LLaMA model to convert.
            config: Configuration for differential attention. If `None`, created from
                base model config.

        Returns:
            Converted model with differential attention.

        Raises:
            ValueError: If number of heads is not even when using `split_heads` mode.
        """
        logger.info(f"Converting {type(model).__name__} to {cls.__name__}")

        # Handle LlamaForCausalLM
        if isinstance(model, LlamaForCausalLM):
            model = model.model

        if config is None:
            config = LlamaDifferentialConfig(**model.config.__dict__)
            logger.debug(f"Created config: {config}")

        # Validate head counts if using split heads mode
        if config.split_heads:
            if config.num_attention_heads % 2 != 0:
                raise ValueError(
                    f"Number of attention heads ({config.num_attention_heads}) must be even "
                    "when using split_heads=True"
                )
            if config.num_key_value_heads % 2 != 0:
                raise ValueError(
                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
                    "when using split_heads=True"
                )

        new_model = cls(config)

        # Copy all weights except attention
        logger.debug("Copying embeddings and norm")
        new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
        new_model.norm.load_state_dict(model.norm.state_dict())

        logger.debug("Copying layer weights")
        for layer_idx, (new_layer, old_layer) in enumerate(
            zip(new_model.layers, model.layers)
        ):
            # Copy everything except attention weights
            new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
            new_layer.input_layernorm.load_state_dict(
                old_layer.input_layernorm.state_dict()
            )
            new_layer.post_attention_layernorm.load_state_dict(
                old_layer.post_attention_layernorm.state_dict()
            )

            # Handle attention weights
            new_layer.self_attn.v_proj.load_state_dict(
                old_layer.self_attn.v_proj.state_dict()
            )
            new_layer.self_attn.o_proj.load_state_dict(
                old_layer.self_attn.o_proj.state_dict()
            )

            # Get the original projection sizes
            old_q_size = old_layer.self_attn.q_proj.weight.size(0)
            old_k_size = old_layer.self_attn.k_proj.weight.size(0)

            if not config.split_heads:
                logger.debug(
                    f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
                )
                new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
                    old_layer.self_attn.q_proj.weight.data
                )
                new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
                    old_layer.self_attn.k_proj.weight.data
                )

                if config.zero_init:
                    logger.debug(f"Layer {layer_idx}: Zero initializing")
                    with torch.no_grad():
                        new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
                        new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
                        new_layer.self_attn.lambda_q1.zero_()
                        new_layer.self_attn.lambda_k1.zero_()
                        new_layer.self_attn.lambda_q2.zero_()
                        new_layer.self_attn.lambda_k2.zero_()
                        new_layer.self_attn.lambda_init.zero_()
                elif config.mirror_weights:
                    # Mirror weights for second component
                    new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
                        old_layer.self_attn.q_proj.weight.data
                    )
                    new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
                        old_layer.self_attn.k_proj.weight.data
                    )

        logger.info("Conversion complete")

        return new_model


class LlamaDifferentialForCausalLM(LlamaForCausalLM):
    """
    `LlamaForCausalLM` with differential attention.

    This class extends the base LLaMA causal language model by incorporating
    differential attention mechanisms.
    """

    config_class = LlamaDifferentialConfig
    base_model_prefix = "llama_differential"

    def __init__(self, config: LlamaDifferentialConfig):
        """
        Initialize a differential LLaMA model for causal language modeling.

        Args:
            config: Configuration object for the model.
        """
        super().__init__(config)
        self.model = LlamaDifferentialModel(config)

    @classmethod
    # pylint: disable=protected-access
    def _autoset_attn_implementation(
        cls,
        config: LlamaDifferentialConfig,
        **kwargs,  # pylint: disable=unused-argument
    ) -> LlamaDifferentialConfig:
        """
        Automatically set the attention implementation based on config.

        Args:
            config: Model configuration object.
            **kwargs: Additional arguments (unused).

        Returns:
            Updated configuration object.

        Raises:
            ValueError: If specified attention implementation is not supported.
        """
        config._attn_implementation_autoset = True
        attn_implementation = getattr(config, "_attn_implementation", None)

        # Map standard types to differential types if mapping exists
        if attn_implementation in config._attn_implementations:
            config._attn_implementation = config._attn_implementations[
                attn_implementation
            ]

            return config

        # If no mapping, validate it's a valid differential type
        valid_impls = [
            None,
            "differential_eager",
            "differential_sdpa",
            "differential_flash_attention_2",
        ]
        if attn_implementation not in valid_impls:
            message = (
                f"Specified `attn_implementation={attn_implementation}` is not supported. "
                f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
            )
            raise ValueError(message)

        return config

    @classmethod
    def from_llama(
        cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
    ) -> "LlamaDifferentialForCausalLM":
        """
        Convert a `LlamaForCausalLM` to use differential attention.

        Args:
            model: Base LLaMA model to convert.
            config: Configuration for differential attention. If `None`, created from
                base model config.

        Returns:
            Converted model with differential attention.

        Raises:
            ValueError: If number of heads is not even when using `split_heads` mode.
        """
        if config is None:
            config = LlamaDifferentialConfig(**model.config.__dict__)

        # Validate head counts if using split heads mode
        if config.split_heads:
            if config.num_attention_heads % 2 != 0:
                raise ValueError(
                    f"Number of attention heads ({config.num_attention_heads}) must be even "
                    "when using split_heads=True"
                )
            if config.num_key_value_heads % 2 != 0:
                raise ValueError(
                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
                    "when using split_heads=True"
                )

        new_model = cls(config)
        new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
        new_model.lm_head.load_state_dict(model.lm_head.state_dict())

        return new_model


def register_diff_attn() -> None:
    """
    Register differential attention components with the transformers library.

    This function registers the differential attention configurations and model classes
    with the Auto* classes from `transformers`, making them available through the
    standard model loading pipeline.
    """
    # Register configs
    AutoConfig.register("llama-differential", LlamaDifferentialConfig)

    # Register models
    AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
    AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)

    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

    LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
    LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
    LLAMA_ATTENTION_CLASSES[
        "differential_flash_attention_2"
    ] = LlamaDifferentialFlashAttention2
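For downstream use, the registration above makes converted checkpoints loadable through the standard Auto* pipeline. A hedged sketch (checkpoint path is illustrative; assumes the relocated package is installed):

```python
# Sketch: loading a converted checkpoint after registration.
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl_diff_transformer.modeling.modeling_diff_attn import register_diff_attn

register_diff_attn()  # maps model_type "llama-differential" to the classes above

model = AutoModelForCausalLM.from_pretrained("./converted-model")  # example path
tokenizer = AutoTokenizer.from_pretrained("./converted-model")
```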