adding diff attn callback, adding documentation

This commit is contained in:
Dan Saunders
2025-01-10 16:28:27 +00:00
parent 443327c585
commit 4f804f6d88
9 changed files with 676 additions and 101 deletions

View File

@@ -202,7 +202,7 @@ def do_inference(
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template": elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
) )

View File

@@ -87,8 +87,6 @@ def convert_diff_transformer(cfg, cli_args, config_path):
zero_init=cli_args.zero_init, zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm, sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads, split_heads=cli_args.split_heads,
init_scale=cli_args.init_scale,
reinit_lambda_init=cli_args.reinit_lambda_init,
) )
model = LlamaDifferentialForCausalLM.from_llama(model, config) model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype) model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -54,10 +54,8 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False) debug: bool = field(default=False)
zero_init: bool = field(default=False) zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=False) sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False) split_heads: bool = field(default=False)
init_scale: float = field(default=1e-6)
reinit_lambda_init: bool = field(default=True)
def load_model_and_tokenizer( def load_model_and_tokenizer(

View File

@@ -1,8 +1,13 @@
"""Definition of differential transformer plugin.""" """Definition of differential transformer plugin."""
import logging import logging
from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -10,10 +15,41 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin): class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl.""" """Plugin for differential transformer integration with Axolotl."""
def __init__(self): def __init__(self) -> None:
"""
Constructor for differential transformers plugin. Calls `register_diff_attn`
to register differential attention custom modeling implementation to `AutoConfig`
and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn from .modeling_diff_attn import register_diff_attn
register_diff_attn() register_diff_attn()
def get_input_args(self): def get_input_args(self) -> str:
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
"""
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
callbacks for the `axolotl` trainer if wandb usage is enabled.
Parameters:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The loaded mfodel.
Returns:
A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
"""
callbacks = []
if cfg.use_wandb:
callbacks.append(
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
)
)
return callbacks

View File

@@ -12,3 +12,5 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer.""" """Input args for differential transformer."""
diff_attention: Optional[bool] = None diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3

View File

@@ -1,9 +1,10 @@
"""Re-implemention of differential attention.""" """Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
import logging import logging
import math import math
from typing import Any, Optional, Tuple from typing import Any
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -27,6 +28,18 @@ except ImportError:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@@ -37,14 +50,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
) )
def lambda_init_fn(depth): def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (decreasing with `depth`).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth) return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttentionBase(nn.Module): class LlamaDifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations.""" """
Base class for differential attention implementations.
def __init__(self, config: Any, layer_idx: int): This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int) -> None:
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states
- num_attention_heads: Number of attention heads
- num_key_value_heads: Number of key/value heads
- attention_bias: Whether to use bias in attention projections
- split_heads: Whether to use split heads mode
- rms_norm_eps: Epsilon for RMS normalization
layer_idx: The index of this layer in the model
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__() super().__init__()
self.config = config self.config = config
@@ -53,8 +100,24 @@ class LlamaDifferentialAttentionBase(nn.Module):
self._init_differential_params() self._init_differential_params()
self._init_normalization() self._init_normalization()
def _init_config(self, layer_idx: int): # For logging
"""Initialize configuration parameters.""" self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
Args:
layer_idx: Index of the current layer.
Note:
In split heads mode, the number of heads is divided by 2 (rounding down),
which differs from the original implementation that required an even number.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads self.base_num_kv_heads = self.config.num_key_value_heads
@@ -62,26 +125,26 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.layer_idx = layer_idx self.layer_idx = layer_idx
if self.config.split_heads: if self.config.split_heads:
# Split heads mode - single projections
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2 self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2 self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim self.value_head_dim = 2 * self.head_dim
else: else:
# Double projection mode
self.heads_per_component = self.base_num_heads self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim self.value_head_dim = self.head_dim
def _init_projections(self): def _init_projections(self) -> None:
"""Initialize Q, K, V projections.""" """
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads: if self.config.split_heads:
# Split heads mode - single projections
q_out_dim = self.config.hidden_size q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads k_out_dim = self.head_dim * self.base_num_kv_heads
else: else:
# Double projection mode
q_out_dim = self.config.hidden_size * 2 q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2 k_out_dim = self.head_dim * self.base_num_kv_heads * 2
@@ -102,8 +165,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
bias=self.config.attention_bias, bias=self.config.attention_bias,
) )
def _init_differential_params(self): def _init_differential_params(self) -> None:
"""Initialize differential attention parameters.""" """
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Lambda parameters for queries and keys
- Initial lambda value based on layer index
- Rotary position embedding layer
"""
self.lambda_init = nn.Parameter( self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)), torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False, requires_grad=False,
@@ -122,19 +192,42 @@ class LlamaDifferentialAttentionBase(nn.Module):
) )
self.rotary_emb = LlamaRotaryEmbedding(config=self.config) self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self): def _init_normalization(self) -> None:
"""Initialize normalization layers.""" """
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True) sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm: if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps) self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else: else:
self.subln = nn.Identity() self.subln = nn.Identity()
def _prepare_attention_inputs(self, hidden_states: torch.Tensor): def _prepare_attention_inputs(
"""Prepare inputs for attention computation.""" self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
# Project and split
q = self.q_proj(hidden_states) q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states) k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states) v = self.v_proj(hidden_states)
@@ -158,9 +251,41 @@ class LlamaDifferentialAttentionBase(nn.Module):
return q1, q2, k1, k2, v return q1, q2, k1, k2, v
def _apply_rotary_embeddings( def _apply_rotary_embeddings(
self, q1, q2, k1, k2, position_ids, position_embeddings self,
): q1: torch.Tensor,
"""Apply rotary embeddings to queries and keys.""" q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None: if position_embeddings is None:
LOG.warning( LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -177,14 +302,36 @@ class LlamaDifferentialAttentionBase(nn.Module):
return q1, q2, k1, k2, cos, sin return q1, q2, k1, k2, cos, sin
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs): def _handle_cache(
"""Handle caching for autoregressive generation.""" self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None: if past_key_value is not None:
k = torch.stack([k1, k2], dim=1) k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1) k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match number of query heads
k1 = repeat_kv(k1, self.num_key_value_groups) k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups) k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups) v = repeat_kv(v, self.num_key_value_groups)
@@ -193,39 +340,90 @@ class LlamaDifferentialAttentionBase(nn.Module):
return k1, k2, v return k1, k2, v
def _compute_lambda(self, q1): def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""Compute lambda values for differential attention.""" """
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
computed from the learned parameters.
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp( lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1) ).type_as(q1)
lambda_2 = torch.exp( lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1) ).type_as(q1)
return lambda_1 - lambda_2 + self.lambda_init return lambda_1 - lambda_2 + self.lambda_init
def _process_attention_output(self, attn, bsz, q_len): def _process_attention_output(
"""Process and project attention output.""" self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization,
scales by (1 - λ_init), and projects back to model dimension.
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn) attn = self.subln(attn)
attn = attn * (1 - self.lambda_init) attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn) return self.o_proj(attn)
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""Standard implementation of differential attention.""" """
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_value: Optional[Cache] = None, past_key_value: Cache | None = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None, cache_position: torch.LongTensor | None = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
@@ -255,6 +453,11 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len) attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions: if output_attentions:
attn_weights = attn1 - lambda_full * attn2 attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1) attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
@@ -263,27 +466,53 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase): class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""SDPA-based implementation of differential attention.""" """
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
# pylint: disable=duplicate-code
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_value: Optional[Cache] = None, past_key_value: Cache | None = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: torch.LongTensor | None = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions: if output_attentions:
LOG.warning( LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but " "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support " + "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation." + "`output_attentions=True`. Falling back to the eager attention implementation."
) )
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward( return LlamaDifferentialAttention.forward(
self, self,
hidden_states, hidden_states,
@@ -326,15 +555,35 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
lambda_full = self._compute_lambda(q1) lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2 attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len) attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase): class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention.""" """
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE: if not FLASH_ATTENTION_AVAILABLE:
raise ImportError( raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. " "LlamaDifferentialFlashAttention2 requires flash-attn library. "
@@ -343,25 +592,46 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# pylint: disable=duplicate-code
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_value: Optional[Cache] = None, past_key_value: Cache | None = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: torch.LongTensor | None = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions: if output_attentions:
LOG.warning( LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but " "LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back " + "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation." + "to the eager attention implementation."
) )
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward( return LlamaDifferentialAttention.forward(
self, self,
hidden_states, hidden_states,
@@ -407,6 +677,11 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
lambda_full = self._compute_lambda(q1) lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2 attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len) attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value return attn, None, past_key_value

View File

@@ -1,7 +1,11 @@
"""Modeling for differential transformers.""" """
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging import logging
from typing import Optional, Union
import torch import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,7 +22,12 @@ logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig): class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model.""" """
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential" model_type = "llama-differential"
@@ -29,6 +38,15 @@ class LlamaDifferentialConfig(LlamaConfig):
zero_init: bool = False, zero_init: bool = False,
**kwargs, **kwargs,
): ):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self.split_heads = split_heads self.split_heads = split_heads
self.sublayer_norm = sublayer_norm self.sublayer_norm = sublayer_norm
@@ -42,12 +60,26 @@ class LlamaDifferentialConfig(LlamaConfig):
class LlamaDifferentialModel(LlamaModel): class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention.""" """
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential" base_model_prefix = "llama_differential"
def __init__(self, config): def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config) super().__init__(config)
# Handle attention implementation # Handle attention implementation
@@ -76,11 +108,26 @@ class LlamaDifferentialModel(LlamaModel):
for idx, layer in enumerate(self.layers): for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx) layer.self_attn = attn_class(config, idx)
# pylint: disable=protected-access
@classmethod @classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation( def _autoset_attn_implementation(
cls, config, **kwargs cls,
): # pylint: disable=unused-argument config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None) attn_implementation = getattr(config, "_attn_implementation", None)
@@ -110,10 +157,23 @@ class LlamaDifferentialModel(LlamaModel):
@classmethod @classmethod
def from_llama( def from_llama(
cls, cls,
model: Union[LlamaModel, LlamaForCausalLM], model: LlamaModel | LlamaForCausalLM,
config: Optional[LlamaDifferentialConfig] = None, config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel": ) -> "LlamaDifferentialModel":
"""Convert a LlamaModel to use differential attention.""" """
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}") logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM # Handle LlamaForCausalLM
@@ -182,7 +242,6 @@ class LlamaDifferentialModel(LlamaModel):
if config.zero_init: if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing") logger.debug(f"Layer {layer_idx}: Zero initializing")
# Zero out components as needed
with torch.no_grad(): with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_() new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_() new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
@@ -192,45 +251,60 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.lambda_k2.zero_() new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_() new_layer.self_attn.lambda_init.zero_()
else: else:
logger.debug( # Mirror weights for second component
f"Layer {layer_idx}: Initializing with scale {config.init_scale}" new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
) )
# Initialize with small random values
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
if config.reinit_lambda_init:
new_layer.self_attn.lambda_init.normal_(
0, config.init_scale
).abs_()
logger.info("Conversion complete") logger.info("Conversion complete")
return new_model return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM): class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""LlamaForCausalLM with differential attention.""" """
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential" base_model_prefix = "llama_differential"
def __init__(self, config): def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config) super().__init__(config)
self.model = LlamaDifferentialModel(config) self.model = LlamaDifferentialModel(config)
# pylint: disable=protected-access
@classmethod @classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation( def _autoset_attn_implementation(
cls, config, **kwargs cls,
): # pylint: disable=unused-argument config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None) attn_implementation = getattr(config, "_attn_implementation", None)
@@ -239,6 +313,7 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
config._attn_implementation = config._attn_implementations[ config._attn_implementation = config._attn_implementations[
attn_implementation attn_implementation
] ]
return config return config
# If no mapping, validate it's a valid differential type # If no mapping, validate it's a valid differential type
@@ -259,9 +334,22 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
@classmethod @classmethod
def from_llama( def from_llama(
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM": ) -> "LlamaDifferentialForCausalLM":
"""Convert a LlamaForCausalLM to use differential attention.""" """
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None: if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__) config = LlamaDifferentialConfig(**model.config.__dict__)
@@ -285,7 +373,14 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
return new_model return new_model
def register_diff_attn(): def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs # Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig) AutoConfig.register("llama-differential", LlamaDifferentialConfig)

View File

@@ -0,0 +1,171 @@
"""
Monitor and log differential attention components during training.
This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""
from typing import Any
import torch
import wandb
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
class DifferentialAttentionMonitorCallback(TrainerCallback):
"""
Callback to monitor differential attention components and lambda parameters.
This callback tracks attention statistics across all layers and provides detailed
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""
Set up layer monitoring at the start of training.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if is_main_process():
num_layers = len(model.model.layers)
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
stride = (
(num_layers - 1) / (self.num_monitor_layers - 1)
if self.num_monitor_layers > 1
else 0
)
self.monitor_layers = [
round(i * stride) for i in range(self.num_monitor_layers)
]
print(f"Monitoring layers {self.monitor_layers} in detail")
# pylint: disable=unused-argument
def on_step_end(
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
) -> None:
"""
Log attention metrics at the end of each step.
Collects and logs:
- Lambda parameter norms and values.
- Attention statistics (mean and std).
- Both per-layer and aggregate metrics.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if not is_main_process() or state.global_step % self.log_every != 0:
return
assert self.monitor_layers is not None
# Aggregate stats across all layers
all_q1_norms = []
all_q2_norms = []
all_k1_norms = []
all_k2_norms = []
all_lambda1 = []
all_lambda2 = []
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
# Collect stats for aggregation
all_q1_norms.append(attn.lambda_q1.norm().item())
all_q2_norms.append(attn.lambda_q2.norm().item())
all_k1_norms.append(attn.lambda_k1.norm().item())
all_k2_norms.append(attn.lambda_k2.norm().item())
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
all_lambda1.append(lambda1)
all_lambda2.append(lambda2)
all_lambda_full.append(attn.lambda_full)
# Log detailed metrics for monitored layers
if layer_idx in self.monitor_layers:
metrics.update(
{
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
f"layer_{layer_idx}/lambda1": lambda1,
f"layer_{layer_idx}/lambda2": lambda2,
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
f"layer_{layer_idx}/lambda_full": lambda1
- lambda2
+ attn.lambda_init.item(),
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
}
)
# Add aggregate metrics
metrics.update(
{
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
.mean()
.item(),
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
.mean()
.item(),
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
.mean()
.item(),
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
.mean()
.item(),
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
.mean()
.item(),
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
}
)
wandb.log(metrics, step=state.global_step)

View File

@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
yaml.dump(base_config, file) yaml.dump(base_config, file)
cfg = load_cfg(str(config_path)) cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1) cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"] assert not debug_info["generations_match"]