Compare commits

...

3 Commits

Author SHA1 Message Date
Dan Saunders
66262c3092 moving out all diff attn code to plugin repo 2025-01-24 17:46:11 +00:00
Dan Saunders
016ba124e4 README update 2025-01-23 22:11:35 +00:00
Dan Saunders
7145d52d99 moving diff attn code to separate repo 2025-01-23 21:33:53 +00:00
15 changed files with 50 additions and 1907 deletions

View File

@@ -1,208 +0,0 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
def convert_diff_transformer(cfg, cli_args, config_path):
assert not (
cli_args.split_heads and cli_args.zero_init
), "Both `split_heads` and `zero_init` cannot be `True`"
assert not (
cli_args.zero_init and cli_args.mirror_weights
), "Both `zero_init` and `mirror_weights` cannot be `True`"
debug_info = {}
# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)
# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)
# Test original model
if cli_args.debug:
LOG.info("Testing original model...")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)
try:
# Convert attention
LOG.info("Converting to differential attention...")
config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
mirror_weights=cli_args.mirror_weights,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)
# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}
modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]
# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)
if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False
if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False
return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_diff_transformer(cfg, cli_args, config)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,23 +1,20 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
)
from axolotl.common.cli import (
ConvertDiffTransformerCliArgs,
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -248,19 +245,6 @@ def merge_lora(
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_diff_transformer import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
@@ -275,6 +259,9 @@ def fetch(directory: str, dest: Optional[str]):
fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main():
cli()

View File

@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -4,13 +4,19 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
if TYPE_CHECKING:
try:
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
except: # noqa: E722 # pylint: disable=bare-except # nosec B110
pass
configure_logging()
LOG = logging.getLogger(__name__)
@@ -48,21 +54,10 @@ class EvaluateCliArgs:
debug_num_examples: int = field(default=0)
@dataclass
class ConvertDiffTransformerCliArgs:
"""dataclass with arguments for convert-diff-transformer CLI"""
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
mirror_weights: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -1,12 +0,0 @@
# Differential Transformer
### Usage
**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command.
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
```

View File

@@ -1,67 +0,0 @@
"""Definition of differential transformer plugin."""
import logging
from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.diff_attn import (
DifferentialAttentionMixingCallback,
DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""
def __init__(self) -> None:
"""
Constructor for differential transformers plugin. Calls `register_diff_attn`
to register differential attention custom modeling implementation to `AutoConfig`
and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
def get_input_args(self) -> str:
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
"""
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
callbacks for the `axolotl` trainer if wandb usage is enabled.
Parameters:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The loaded mfodel.
Returns:
A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
"""
callbacks = []
if cfg.use_wandb:
callbacks.append(
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
warmup_steps=cfg.diff_attn_warmup_steps,
)
)
if cfg.diff_attn_warmup_steps:
callbacks.append(
DifferentialAttentionMixingCallback(
warmup_steps=cfg.diff_attn_warmup_steps
)
)
return callbacks

View File

@@ -1,27 +0,0 @@
"""Module for handling differential transfomer input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""
Input args for differential transformer.
Attributes:
diff_attention: Whether to use differential attention layers.
diff_attn_log_every: How often to log differential attention statistics.
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
mixing weight from 0 to 1. If specified, will reach full mixing at this
step. If `None`, negative attention has full weight from the start.
"""
diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3
diff_attn_warmup_steps: Optional[int] = None

View File

@@ -1,694 +0,0 @@
"""Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
)
def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (decreasing with `depth`).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttentionBase(nn.Module):
"""
Base class for differential attention implementations.
This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int):
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model.
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__()
self.config = config
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization()
# For logging
self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args:
layer_idx: Index of the current layer.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
if self.config.split_heads:
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self) -> None:
"""
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads:
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
self.q_proj = nn.Linear(
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.config.hidden_size,
bias=self.config.attention_bias,
)
def _init_differential_params(self) -> None:
"""
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Mixing parameter for negative attention component warmup phase.
- Lambda parameters for queries and keys.
- Initial lambda value based on layer index.
- Rotary position embedding layer.
"""
self.diff_attn_mix = 1.0 # Default to full mixing
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self) -> None:
"""
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
def _prepare_attention_inputs(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self,
q1: torch.Tensor,
q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None:
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(
self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
return k1, k2, v
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
return self.diff_attn_mix * lambda_full
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization,
scales by (1 - λ_init), and projects back to model dimension.
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn)
# NOTE: this may need to be added back in, but doesn't interact well with
# `diff_attn_mix`, and doesn't allow us to preserve the original model output.
# attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Standard attention computation
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions:
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# SDPA-specific attention computation
causal_mask = (
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.config.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
attn1 = F.scaled_dot_product_attention(
q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
attn2 = F.scaled_dot_product_attention(
q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Flash Attention specific processing
q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.config.attention_dropout if self.training else 0.0
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)
attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value

View File

@@ -1,401 +0,0 @@
"""
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential"
def __init__(
self,
split_heads: bool = False,
sublayer_norm: bool = True,
zero_init: bool = False,
mirror_weights: bool = False,
**kwargs,
):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
mirror_weights: Whether to copy the positive attention component weights to
the negative attention component.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
self.zero_init = zero_init
self.mirror_weights = mirror_weights
self.architectures = ["LlamaDifferentialModel"]
self._attn_implementations = {
"eager": "differential_eager",
"sdpa": "differential_sdpa",
"flash_attention_2": "differential_flash_attention_2",
}
class LlamaDifferentialModel(LlamaModel):
"""
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls,
model: LlamaModel | LlamaForCausalLM,
config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel":
"""
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
old_layer.input_layernorm.state_dict()
)
new_layer.post_attention_layernorm.load_state_dict(
old_layer.post_attention_layernorm.state_dict()
)
# Handle attention weights
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
# Get the original projection sizes
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
old_layer.self_attn.k_proj.weight.data
)
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
new_layer.self_attn.lambda_q1.zero_()
new_layer.self_attn.lambda_k1.zero_()
new_layer.self_attn.lambda_q2.zero_()
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
elif config.mirror_weights:
# Mirror weights for second component
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
)
logger.info("Conversion complete")
return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config)
self.model = LlamaDifferentialModel(config)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM":
"""
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model
def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2

View File

@@ -1,234 +0,0 @@
"""
Monitor and log differential attention components during training.
This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""
from typing import Any
import torch
import wandb
from torch import nn
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
class DifferentialAttentionMonitorCallback(TrainerCallback):
"""
Callback to monitor differential attention components and lambda parameters.
This callback tracks attention statistics across all layers and provides detailed
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(
self,
log_every: int = 250,
num_monitor_layers: int = 3,
warmup_steps: int | None = None,
):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
warmup_steps: Optional parameter for negative attention component warmup.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.warmup_steps = warmup_steps
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""
Set up layer monitoring at the start of training.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if is_main_process():
num_layers = len(model.model.layers)
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
stride = (
(num_layers - 1) / (self.num_monitor_layers - 1)
if self.num_monitor_layers > 1
else 0
)
self.monitor_layers = [
round(i * stride) for i in range(self.num_monitor_layers)
]
print(f"Monitoring layers {self.monitor_layers} in detail")
# pylint: disable=unused-argument
def on_step_end(
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
) -> None:
"""
Log attention metrics at the end of each step.
Collects and logs:
- Lambda parameter norms and values.
- Attention statistics (mean and std).
- Both per-layer and aggregate metrics.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if not is_main_process() or state.global_step % self.log_every != 0:
return
assert self.monitor_layers is not None
# Aggregate stats across all layers
all_q1_norms = []
all_q2_norms = []
all_k1_norms = []
all_k2_norms = []
all_lambda1 = []
all_lambda2 = []
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
# Collect stats for aggregation
all_q1_norms.append(attn.lambda_q1.norm().item())
all_q2_norms.append(attn.lambda_q2.norm().item())
all_k1_norms.append(attn.lambda_k1.norm().item())
all_k2_norms.append(attn.lambda_k2.norm().item())
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
all_lambda1.append(lambda1)
all_lambda2.append(lambda2)
all_lambda_full.append(attn.lambda_full)
# Log detailed metrics for monitored layers
if layer_idx in self.monitor_layers:
metrics.update(
{
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
f"layer_{layer_idx}/lambda1": lambda1,
f"layer_{layer_idx}/lambda2": lambda2,
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
f"layer_{layer_idx}/lambda_full": lambda1
- lambda2
+ attn.lambda_init.item(),
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
}
)
# Add aggregate metrics
metrics.update(
{
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
.mean()
.item(),
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
.mean()
.item(),
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
.mean()
.item(),
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
.mean()
.item(),
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
.mean()
.item(),
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
}
)
if self.warmup_steps:
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
wandb.log(metrics, step=state.global_step)
class DifferentialAttentionMixingCallback(TrainerCallback):
"""
Callback to gradually increase the weight of negative attention components during
training.
"""
def __init__(self, warmup_steps: int):
"""
Args:
warmup_steps: Number of steps to linearly increase negative attention
weight from 0 to 1. If `None`, negative attention has full weight from
start.
"""
self.warmup_steps = warmup_steps
self.diff_attention_layers: list[nn.Module] | None = None
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""Cache the differential attention layers at the start of training."""
if model is not None:
# Get the actual model if it's wrapped
if hasattr(model, "module"):
model = model.module
# Cache all differential attention layers
self.diff_attention_layers = [
module for module in model.modules() if hasattr(module, "diff_attn_mix")
]
def on_step_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module = None,
**kwargs,
) -> None:
if self.diff_attention_layers and self.warmup_steps:
# Calculate mixing parameter (0 to 1)
mix = min(1.0, state.global_step / self.warmup_steps)
# Update cached layers
for layer in self.diff_attention_layers:
layer.diff_attn_mix = mix

View File

@@ -1,31 +0,0 @@
"""Shared fixtures for differential transformer conversion tests."""
import pytest
from click.testing import CliRunner
@pytest.fixture(scope="class")
def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca",
},
],
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"val_set_size": 0.1,
"micro_batch_size": 1,
"sequence_len": 2048,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
@pytest.fixture(scope="class")
def cli_runner():
return CliRunner()

View File

@@ -1,51 +0,0 @@
"""End-to-end tests for differential transformer conversion and evaluation."""
# pylint: disable=duplicate-code
from pathlib import Path
import yaml
from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
eval_cfg = load_cfg(str(output_dir))
eval_cli_args = EvaluateCliArgs()
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
assert list(all_metrics.keys()) == [
"train_loss",
"train_model_preparation_time",
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"eval_loss",
"eval_model_preparation_time",
"eval_runtime",
"eval_samples_per_second",
"eval_steps_per_second",
]
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)

View File

@@ -1,150 +0,0 @@
"""End-to-end tests for differential transformer conversion."""
# pylint: disable=redefined-outer-name
# pylint: disable=duplicate-code
from pathlib import Path
from typing import Optional
from unittest.mock import patch
import pytest
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs()
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_debug(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"]
assert not debug_info["match_expected"]
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
# Smallest model with an even number of attention heads
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()