adding diff attn callback, adding documentation

This commit is contained in:
Dan Saunders
2025-01-10 16:28:27 +00:00
parent 443327c585
commit 4f804f6d88
9 changed files with 676 additions and 101 deletions

View File

@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

View File

@@ -87,8 +87,6 @@ def convert_diff_transformer(cfg, cli_args, config_path):
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
init_scale=cli_args.init_scale,
reinit_lambda_init=cli_args.reinit_lambda_init,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -54,10 +54,8 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
init_scale: float = field(default=1e-6)
reinit_lambda_init: bool = field(default=True)
def load_model_and_tokenizer(

View File

@@ -1,8 +1,13 @@
"""Definition of differential transformer plugin."""
import logging
from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -10,10 +15,41 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""
def __init__(self):
def __init__(self) -> None:
"""
Constructor for the differential transformer plugin. Calls `register_diff_attn`
to register the differential attention custom modeling implementation with
`AutoConfig` and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
def get_input_args(self):
def get_input_args(self) -> str:
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
"""
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
callbacks for the `axolotl` trainer if wandb usage is enabled.
Parameters:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The loaded model.
Returns:
A list containing an instantiated `DifferentialAttentionMonitorCallback` if
wandb is enabled, otherwise an empty list.
"""
callbacks = []
if cfg.use_wandb:
callbacks.append(
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
)
)
return callbacks
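As an editor's illustration of how this hook behaves (not part of the commit), a minimal sketch assuming the plugin class is exported from the `axolotl.integrations.diff_transformer` package and that `model` has already been loaded:

# Editor illustration only; the plugin import path and the pre-loaded `model` are assumptions.
from axolotl.integrations.diff_transformer import DifferentialTransformerPlugin
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "use_wandb": True,  # the callback is only attached when wandb is enabled
        "diff_attn_log_every": 100,
        "diff_attn_num_monitor_layers": 3,
    }
)
plugin = DifferentialTransformerPlugin()
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)  # `model` assumed loaded elsewhere
# -> [DifferentialAttentionMonitorCallback(...)] when use_wandb is truthy, otherwise []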

View File

@@ -12,3 +12,5 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3

View File

@@ -1,9 +1,10 @@
"""Re-implemention of differential attention."""
"""Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any, Optional, Tuple
from typing import Any
import torch
import torch.nn.functional as F
@@ -27,6 +28,18 @@ except ImportError:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
@@ -37,14 +50,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)
def lambda_init_fn(depth):
def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (increasing with `depth`, approaching 0.8 for deep layers).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
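A quick numeric check of this schedule (editor illustration, not part of the diff):

import math

def lambda_init_fn(depth: int) -> float:
    # Same formula as above: rises from 0.2 at depth 0 toward 0.8 for deep layers.
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

for depth in (0, 1, 4, 12, 31):
    print(depth, round(lambda_init_fn(depth), 3))
# 0 0.2, 1 0.356, 4 0.619, 12 0.784, 31 0.8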
class LlamaDifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""
"""
Base class for differential attention implementations.
def __init__(self, config: Any, layer_idx: int):
This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int) -> None:
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states
- num_attention_heads: Number of attention heads
- num_key_value_heads: Number of key/value heads
- attention_bias: Whether to use bias in attention projections
- split_heads: Whether to use split heads mode
- rms_norm_eps: Epsilon for RMS normalization
layer_idx: The index of this layer in the model
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__()
self.config = config
@@ -53,8 +100,24 @@ class LlamaDifferentialAttentionBase(nn.Module):
self._init_differential_params()
self._init_normalization()
def _init_config(self, layer_idx: int):
"""Initialize configuration parameters."""
# For logging
self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
Args:
layer_idx: Index of the current layer.
Note:
In split heads mode, the number of heads is divided by 2 (rounding down),
which differs from the original implementation that required an even number.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
@@ -62,26 +125,26 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.layer_idx = layer_idx
if self.config.split_heads:
# Split heads mode - single projections
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
# Double projection mode
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self):
"""Initialize Q, K, V projections."""
def _init_projections(self) -> None:
"""
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads:
# Split heads mode - single projections
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
# Double projection mode
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
@@ -102,8 +165,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
bias=self.config.attention_bias,
)
def _init_differential_params(self):
"""Initialize differential attention parameters."""
def _init_differential_params(self) -> None:
"""
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Lambda parameters for queries and keys
- Initial lambda value based on layer index
- Rotary position embedding layer
"""
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
@@ -122,19 +192,42 @@ class LlamaDifferentialAttentionBase(nn.Module):
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self):
"""Initialize normalization layers."""
def _init_normalization(self) -> None:
"""
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
def _prepare_attention_inputs(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size()
# Project and split
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
@@ -158,9 +251,41 @@ class LlamaDifferentialAttentionBase(nn.Module):
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self, q1, q2, k1, k2, position_ids, position_embeddings
):
"""Apply rotary embeddings to queries and keys."""
self,
q1: torch.Tensor,
q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None:
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -177,14 +302,36 @@ class LlamaDifferentialAttentionBase(nn.Module):
return q1, q2, k1, k2, cos, sin
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
"""Handle caching for autoregressive generation."""
def _handle_cache(
self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match number of query heads
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
@@ -193,39 +340,90 @@ class LlamaDifferentialAttentionBase(nn.Module):
return k1, k2, v
def _compute_lambda(self, q1):
"""Compute lambda values for differential attention."""
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
computed from the learned parameters.
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
return lambda_1 - lambda_2 + self.lambda_init
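A small stand-alone check of this reparameterization (editor illustration): with the lambda vectors zeroed, as the `zero_init` conversion path does, both exponentials equal one and the result collapses to `lambda_init`.

import torch

head_dim = 8
lambda_q1 = lambda_k1 = torch.zeros(head_dim)
lambda_q2 = lambda_k2 = torch.zeros(head_dim)
lambda_init = 0.2  # lambda_init_fn(0); note the zero_init path also zeroes this

lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1))  # exp(0) == 1
lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2))  # exp(0) == 1
lambda_full = lambda_1 - lambda_2 + lambda_init          # 1 - 1 + 0.2 == 0.2
assert abs(lambda_full.item() - lambda_init) < 1e-6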
def _process_attention_output(self, attn, bsz, q_len):
"""Process and project attention output."""
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization,
scales by (1 - λ_init), and projects back to model dimension.
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape `(batch_size, seq_len, hidden_size)`.
"""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""Standard implementation of differential attention."""
"""
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
@@ -255,6 +453,11 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions:
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
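For readers following the math, a minimal stand-alone sketch of the combination computed above (editor illustration with arbitrary shapes and a fixed stand-in for the learned lambda; masking, RoPE, and caching are omitted):

import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 1, 2, 4, 8
q1, q2, k1, k2, v = (torch.randn(bsz, heads, q_len, head_dim) for _ in range(5))
lambda_full = 0.35  # stand-in for lambda_1 - lambda_2 + lambda_init

scale = head_dim**-0.5
attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)  # "positive" attention map
attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)  # "negative" attention map
out = (attn1 @ v) - lambda_full * (attn2 @ v)                 # differential combination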
@@ -263,27 +466,53 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""SDPA-based implementation of differential attention."""
"""
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
# pylint: disable=duplicate-code
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -326,15 +555,35 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention."""
"""
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
@@ -343,25 +592,46 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
super().__init__(*args, **kwargs)
# pylint: disable=duplicate-code
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -407,6 +677,11 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value

View File

@@ -1,7 +1,11 @@
"""Modeling for differential transformers."""
"""
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging
from typing import Optional, Union
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,7 +22,12 @@ logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model."""
"""
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential"
@@ -29,6 +38,15 @@ class LlamaDifferentialConfig(LlamaConfig):
zero_init: bool = False,
**kwargs,
):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
@@ -42,12 +60,26 @@ class LlamaDifferentialConfig(LlamaConfig):
class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention."""
"""
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config):
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config)
# Handle attention implementation
@@ -76,11 +108,26 @@ class LlamaDifferentialModel(LlamaModel):
for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx)
# pylint: disable=protected-access
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
@@ -110,10 +157,23 @@ class LlamaDifferentialModel(LlamaModel):
@classmethod
def from_llama(
cls,
model: Union[LlamaModel, LlamaForCausalLM],
config: Optional[LlamaDifferentialConfig] = None,
model: LlamaModel | LlamaForCausalLM,
config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel":
"""Convert a LlamaModel to use differential attention."""
"""
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
@@ -182,7 +242,6 @@ class LlamaDifferentialModel(LlamaModel):
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
# Zero out components as needed
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
@@ -192,45 +251,60 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
else:
logger.debug(
f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
# Mirror weights for second component
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
)
# Initialize with small random values
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
if config.reinit_lambda_init:
new_layer.self_attn.lambda_init.normal_(
0, config.init_scale
).abs_()
logger.info("Conversion complete")
return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""LlamaForCausalLM with differential attention."""
"""
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config):
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config)
self.model = LlamaDifferentialModel(config)
# pylint: disable=protected-access
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
@@ -239,6 +313,7 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
@@ -259,9 +334,22 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM":
"""Convert a LlamaForCausalLM to use differential attention."""
"""
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
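A hedged usage sketch for the conversion path (editor illustration; the checkpoint path is a placeholder and loading details are assumed):

from transformers import LlamaForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialForCausalLM,
)

base = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
# With config=None, a LlamaDifferentialConfig is built from the base model's config.
diff_model = LlamaDifferentialForCausalLM.from_llama(base)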
@@ -285,7 +373,14 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
return new_model
def register_diff_attn():
def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
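A hedged illustration of what registration enables (editor illustration; the module path is inferred from the imports shown earlier):

from transformers import AutoConfig

from axolotl.integrations.diff_transformer.modeling_diff_attn import register_diff_attn

register_diff_attn()
config = AutoConfig.for_model("llama-differential")  # resolves to LlamaDifferentialConfig
print(type(config).__name__)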

View File

@@ -0,0 +1,171 @@
"""
Monitor and log differential attention components during training.
This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""
from typing import Any
import torch
import wandb
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
class DifferentialAttentionMonitorCallback(TrainerCallback):
"""
Callback to monitor differential attention components and lambda parameters.
This callback tracks attention statistics across all layers and provides detailed
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""
Set up layer monitoring at the start of training.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if is_main_process():
num_layers = len(model.model.layers)
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
stride = (
(num_layers - 1) / (self.num_monitor_layers - 1)
if self.num_monitor_layers > 1
else 0
)
self.monitor_layers = [
round(i * stride) for i in range(self.num_monitor_layers)
]
print(f"Monitoring layers {self.monitor_layers} in detail")
# pylint: disable=unused-argument
def on_step_end(
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
) -> None:
"""
Log attention metrics at the end of each step.
Collects and logs:
- Lambda parameter norms and values.
- Attention statistics (mean and std).
- Both per-layer and aggregate metrics.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if not is_main_process() or state.global_step % self.log_every != 0:
return
assert self.monitor_layers is not None
# Aggregate stats across all layers
all_q1_norms = []
all_q2_norms = []
all_k1_norms = []
all_k2_norms = []
all_lambda1 = []
all_lambda2 = []
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
# Collect stats for aggregation
all_q1_norms.append(attn.lambda_q1.norm().item())
all_q2_norms.append(attn.lambda_q2.norm().item())
all_k1_norms.append(attn.lambda_k1.norm().item())
all_k2_norms.append(attn.lambda_k2.norm().item())
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
all_lambda1.append(lambda1)
all_lambda2.append(lambda2)
all_lambda_full.append(attn.lambda_full)
# Log detailed metrics for monitored layers
if layer_idx in self.monitor_layers:
metrics.update(
{
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
f"layer_{layer_idx}/lambda1": lambda1,
f"layer_{layer_idx}/lambda2": lambda2,
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
f"layer_{layer_idx}/lambda_full": lambda1
- lambda2
+ attn.lambda_init.item(),
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
}
)
# Add aggregate metrics
metrics.update(
{
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
.mean()
.item(),
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
.mean()
.item(),
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
.mean()
.item(),
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
.mean()
.item(),
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
.mean()
.item(),
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
}
)
wandb.log(metrics, step=state.global_step)
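Finally, a hedged sketch of wiring this callback into a Hugging Face Trainer directly (editor illustration; axolotl normally attaches it through the plugin hook, and `model`, `training_args`, and `train_dataset` are assumed to exist):

from transformers import Trainer

from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.add_callback(
    DifferentialAttentionMonitorCallback(log_every=100, num_monitor_layers=3)
)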

View File

@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"]