diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index d07b10ce3..278a67474 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
- elif cfg.datasets[0].type == "chat_template":
+ elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index 03126f3bf..ecde82251 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -87,8 +87,6 @@ def convert_diff_transformer(cfg, cli_args, config_path):
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
- init_scale=cli_args.init_scale,
- reinit_lambda_init=cli_args.reinit_lambda_init,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 8f0d6bb77..ebe098ca6 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -54,10 +54,8 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False)
zero_init: bool = field(default=False)
- sublayer_norm: bool = field(default=False)
+ sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
- init_scale: float = field(default=1e-6)
- reinit_lambda_init: bool = field(default=True)
def load_model_and_tokenizer(
diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py
index 1dbae22c4..e0ebec818 100644
--- a/src/axolotl/integrations/diff_transformer/__init__.py
+++ b/src/axolotl/integrations/diff_transformer/__init__.py
@@ -1,8 +1,13 @@
"""Definition of differential transformer plugin."""
import logging
+from typing import List
+
+from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
+from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
+from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -10,10 +15,41 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""
- def __init__(self):
+ def __init__(self) -> None:
+ """
+ Constructor for the differential transformer plugin. Calls `register_diff_attn`
+ to register the custom differential attention modeling implementation with `AutoConfig`
+ and `AutoModel`.
+ """
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
- def get_input_args(self):
+ def get_input_args(self) -> str:
+ """Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
+
+ def add_callbacks_pre_trainer(
+ self, cfg: DictDefault, model: PreTrainedModel
+ ) -> List[TrainerCallback]:
+ """
+ Returns `DifferentialAttentionMonitorCallback` to be added to the list of
+ callbacks for the `axolotl` trainer if wandb usage is enabled.
+
+ Parameters:
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ model: The loaded model.
+
+ Returns:
+ A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
+ """
+ callbacks = []
+ if cfg.use_wandb:
+ callbacks.append(
+ DifferentialAttentionMonitorCallback(
+ log_every=cfg.diff_attn_log_every,
+ num_monitor_layers=cfg.diff_attn_num_monitor_layers,
+ )
+ )
+
+ return callbacks
diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py
index 47c1fe110..fe3e7d977 100644
--- a/src/axolotl/integrations/diff_transformer/args.py
+++ b/src/axolotl/integrations/diff_transformer/args.py
@@ -12,3 +12,5 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
diff_attention: Optional[bool] = None
+ diff_attn_log_every: Optional[int] = 100
+ diff_attn_num_monitor_layers: Optional[int] = 3
diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py
index a03a3fb00..cc3e8f90c 100644
--- a/src/axolotl/integrations/diff_transformer/diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -1,9 +1,10 @@
-"""Re-implemention of differential attention."""
+"""Re-implementation of differential attention from the Differential Transformer paper
+(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
-from typing import Any, Optional, Tuple
+from typing import Any
import torch
import torch.nn.functional as F
@@ -27,6 +28,18 @@ except ImportError:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ Repeats key/value heads to match the number of query heads in multi-head attention.
+
+ Args:
+ x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
+ n_rep: Number of times to repeat each head.
+
+ Returns:
+ Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
+ seq_len, head_dim)`.
+ If `n_rep` is 1, returns the input tensor unchanged.
+ """
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
@@ -37,14 +50,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)
-def lambda_init_fn(depth):
+def lambda_init_fn(depth: int) -> float:
+ """
+ Lambda mixing parameter init function from the "Differential Transformer" paper.
+
+ Args:
+ depth: Index of the layer for which to initialize the lambda parameter.
+
+ Returns:
+ Lambda initialization value (decreasing with `depth`).
+ """
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttentionBase(nn.Module):
- """Base class for differential attention implementations."""
+ """
+ Base class for differential attention implementations.
- def __init__(self, config: Any, layer_idx: int):
+ This class implements the core differential attention mechanism used in Llama models.
+ It supports both split heads and double projection modes for attention computation.
+ """
+
+ def __init__(self, config: Any, layer_idx: int) -> None:
+ """
+ Initializes the differential attention module.
+
+ Args:
+ config: Model configuration object containing hyperparameters, including:
+ - hidden_size: The size of hidden states
+ - num_attention_heads: Number of attention heads
+ - num_key_value_heads: Number of key/value heads
+ - attention_bias: Whether to use bias in attention projections
+ - split_heads: Whether to use split heads mode
+ - rms_norm_eps: Epsilon for RMS normalization
+ layer_idx: The index of this layer in the model
+
+ Note:
+ The initialization process consists of four steps:
+ 1. Configuration initialization (`_init_config`)
+ 2.
Projection layers initialization (`_init_projections`) + 3. Differential parameters initialization (`_init_differential_params`) + 4. Normalization layers initialization (`_init_normalization`) + """ super().__init__() self.config = config @@ -53,8 +100,24 @@ class LlamaDifferentialAttentionBase(nn.Module): self._init_differential_params() self._init_normalization() - def _init_config(self, layer_idx: int): - """Initialize configuration parameters.""" + # For logging + self.attn1 = None + self.attn2 = None + self.lambda_full = None + + def _init_config(self, layer_idx: int) -> None: + """ + Initializes configuration parameters for the attention layer. Sets up various + dimension sizes and head counts based on the provided config. Handles both + split heads and double projection modes. + + Args: + layer_idx: Index of the current layer. + + Note: + In split heads mode, the number of heads is divided by 2 (rounding down), + which differs from the original implementation that required an even number. + """ self.head_dim = self.config.hidden_size // self.config.num_attention_heads self.base_num_heads = self.config.num_attention_heads self.base_num_kv_heads = self.config.num_key_value_heads @@ -62,26 +125,26 @@ class LlamaDifferentialAttentionBase(nn.Module): self.layer_idx = layer_idx if self.config.split_heads: - # Split heads mode - single projections - # NOTE: This rounds down `base_num_heads / 2` as opposed to the original - # implementation, which asserts `self.base_num_heads` is even self.heads_per_component = self.base_num_heads // 2 self.kv_heads_per_component = self.base_num_kv_heads // 2 self.value_head_dim = 2 * self.head_dim else: - # Double projection mode self.heads_per_component = self.base_num_heads self.kv_heads_per_component = self.base_num_kv_heads self.value_head_dim = self.head_dim - def _init_projections(self): - """Initialize Q, K, V projections.""" + def _init_projections(self) -> None: + """ + Initializes the query, key, value, and output projection layers. + + Creates linear transformations for Q, K, V projections with dimensions + depending on whether split heads or double projection mode is used. + The output projection combines the attention heads back to model dimension. + """ if self.config.split_heads: - # Split heads mode - single projections q_out_dim = self.config.hidden_size k_out_dim = self.head_dim * self.base_num_kv_heads else: - # Double projection mode q_out_dim = self.config.hidden_size * 2 k_out_dim = self.head_dim * self.base_num_kv_heads * 2 @@ -102,8 +165,15 @@ class LlamaDifferentialAttentionBase(nn.Module): bias=self.config.attention_bias, ) - def _init_differential_params(self): - """Initialize differential attention parameters.""" + def _init_differential_params(self) -> None: + """ + Initializes parameters specific to differential attention. + + Creates learnable parameters for the differential attention mechanism: + - Lambda parameters for queries and keys + - Initial lambda value based on layer index + - Rotary position embedding layer + """ self.lambda_init = nn.Parameter( torch.full((), lambda_init_fn(self.layer_idx)), requires_grad=False, @@ -122,19 +192,42 @@ class LlamaDifferentialAttentionBase(nn.Module): ) self.rotary_emb = LlamaRotaryEmbedding(config=self.config) - def _init_normalization(self): - """Initialize normalization layers.""" + def _init_normalization(self) -> None: + """ + Initializes normalization layers for the attention mechanism. + + Sets up either RMS normalization or identity transformation based on config. 
+ The normalization is applied to the sublayer output if enabled. + """ sublayer_norm = getattr(self.config, "sublayer_norm", True) if sublayer_norm: self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps) else: self.subln = nn.Identity() - def _prepare_attention_inputs(self, hidden_states: torch.Tensor): - """Prepare inputs for attention computation.""" + def _prepare_attention_inputs( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepares input tensors for attention computation. + + Projects input hidden states to query, key, and value spaces, then reshapes + them for multi-head attention processing. + + Args: + hidden_states: Input tensor of shape `(batch_size, seq_len, + hidden_size)`. + + Returns: + tuple: Tuple containing: + - q1: Positive attention query component + - q2: Negative attention query component + - k1: Positive attention key component + - k2: Negative attention key component + - v: Value tensor + """ bsz, q_len, _ = hidden_states.size() - # Project and split q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) @@ -158,9 +251,41 @@ class LlamaDifferentialAttentionBase(nn.Module): return q1, q2, k1, k2, v def _apply_rotary_embeddings( - self, q1, q2, k1, k2, position_ids, position_embeddings - ): - """Apply rotary embeddings to queries and keys.""" + self, + q1: torch.Tensor, + q2: torch.Tensor, + k1: torch.Tensor, + k2: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Applies rotary positional embeddings to queries and keys. + + Args: + q1: Positive attention query component. + q2: Negative attention query component. + k1: Positive attention key component. + k2: Negative attention key component. + position_ids: Token position indices. + position_embeddings: Pre-computed rotary embeddings (cos, sin). + + Returns: + tuple: Tuple containing: + - q1: Positive attention query with positional encoding. + - q2: Negative attention query with positional encoding. + - k1: Positive attention key with positional encoding. + - k2: Negative attention key with positional encoding. + - cos: Cosine part of rotary embeddings. + - sin: Sine part of rotary embeddings. + """ if position_embeddings is None: LOG.warning( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -177,14 +302,36 @@ class LlamaDifferentialAttentionBase(nn.Module): return q1, q2, k1, k2, cos, sin - def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs): - """Handle caching for autoregressive generation.""" + def _handle_cache( + self, + k1: torch.Tensor, + k2: torch.Tensor, + v: torch.Tensor, + past_key_value: Cache | None, + cache_kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Handles key-value caching for autoregressive generation and the repetition of + key-value heads to match the number of query heads. + + Args: + k1: Positive attention key component. + k2: Negative attention key component. + v: Value tensor. + past_key_value: Cache object for storing previous key-value pairs. + cache_kwargs: Additional arguments for cache handling. + + Returns: + tuple: Tuple containing: + - k1: Processed positive attention key component. + - k2: Processed negative attention key component. + - v: Processed value tensor. 
+ """ if past_key_value is not None: k = torch.stack([k1, k2], dim=1) k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) k1, k2 = k.unbind(dim=1) - # Repeat KV heads to match number of query heads k1 = repeat_kv(k1, self.num_key_value_groups) k2 = repeat_kv(k2, self.num_key_value_groups) v = repeat_kv(v, self.num_key_value_groups) @@ -193,39 +340,90 @@ class LlamaDifferentialAttentionBase(nn.Module): return k1, k2, v - def _compute_lambda(self, q1): - """Compute lambda values for differential attention.""" + def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor: + """ + Computes lambda values for differential attention. + + The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are + computed from the learned parameters. + + Args: + q1: Positive attention query component, used for type casting. + + Returns: + Computed lambda value for differential attention. + """ lambda_1 = torch.exp( torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() ).type_as(q1) lambda_2 = torch.exp( torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() ).type_as(q1) + return lambda_1 - lambda_2 + self.lambda_init - def _process_attention_output(self, attn, bsz, q_len): - """Process and project attention output.""" + def _process_attention_output( + self, attn: torch.Tensor, bsz: int, q_len: int + ) -> torch.Tensor: + """ + Processes and projects the attention output. Applies sublayer normalization, + scales by (1 - λ_init), and projects back to model dimension. + + Args: + attn: Raw attention output. + bsz: Batch size. + q_len: Query sequence length. + + Returns: + Processed attention output of shape (batch_size, seq_len, hidden_size) + """ attn = self.subln(attn) attn = attn * (1 - self.lambda_init) attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) + return self.o_proj(attn) class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): - """Standard implementation of differential attention.""" + """ + Standard implementation of differential attention. + + This class implements the standard differential attention mechanism using + explicit matrix multiplications for the attention computation. + """ def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, # pylint: disable=unused-argument - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, # pylint: disable=unused-argument ): + """ + Computes differential attention using standard matrix multiplication operations. + + Args: + hidden_states: Input tensor containing sequence to attend to. + attention_mask: Mask to avoid attention on padding tokens. + position_ids: Indices of positions for positional embeddings. + past_key_value: Cached key and value tensors for autoregressive decoding. + output_attentions: Whether to return attention weights. + use_cache: Whether to use cached key/value states. + cache_position: Position indices for cached states. + position_embeddings: Pre-computed positional embeddings. + **kwargs: Additional arguments passed to the forward call. 
+ + Returns: + tuple containing: + - Output tensor after attention computation. + - Attention weights if output_attentions is True, else None. + - Updated key-value cache if use_cache is True, else None. + """ bsz, q_len, _ = hidden_states.size() q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( @@ -255,6 +453,11 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) attn = self._process_attention_output(attn, bsz, q_len) + # Save for logging + self.attn1 = attn1 + self.attn2 = attn2 + self.lambda_full = lambda_full + if output_attentions: attn_weights = attn1 - lambda_full * attn2 attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1) @@ -263,27 +466,53 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase): - """SDPA-based implementation of differential attention.""" + """ + SDPA-based implementation of differential attention. + + This class implements differential attention using PyTorch's scaled_dot_product_attention + for improved performance on supported hardware. + """ - # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, # pylint: disable=unused-argument ): + """ + Computes differential attention using PyTorch's scaled dot product attention. + + Args: + hidden_states: Input tensor containing sequence to attend to. + attention_mask: Mask to avoid attention on padding tokens. + position_ids: Indices of positions for positional embeddings. + past_key_value: Cached key and value tensors for autoregressive decoding. + output_attentions: Whether to return attention weights. + use_cache: Whether to use cached key/value states. + cache_position: Position indices for cached states. + position_embeddings: Pre-computed positional embeddings. + **kwargs: Additional arguments passed to the forward call. + + Returns: + tuple containing: + - Output tensor after attention computation. + - None for attention weights (SDPA doesn't support output_attentions). + - Updated key-value cache if use_cache is True, else None. + """ if output_attentions: LOG.warning( "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but " + "`torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to the eager attention implementation." 
)
+
+ # pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -326,15 +555,35 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
- attn = self._process_attention_output(attn, bsz, q_len)
+
+ # Save for logging
+ self.attn1 = attn1
+ self.attn2 = attn2
+ self.lambda_full = lambda_full
+
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
- """Flash Attention 2-based implementation of differential attention."""
+ """
+ Flash Attention 2-based implementation of differential attention.
+
+ This class implements differential attention using Flash Attention 2 for maximum
+ performance on supported hardware.
+ """
def __init__(self, *args, **kwargs):
+ """
+ Initializes the Flash Attention 2 differential attention module.
+
+ Args:
+ *args: Positional arguments passed to parent class.
+ **kwargs: Keyword arguments passed to parent class.
+
+ Raises:
+ ImportError: If flash-attn library is not installed.
+ """
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
@@ -343,25 +592,46 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
super().__init__(*args, **kwargs)
- # pylint: disable=duplicate-code
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ cache_position: torch.LongTensor | None = None,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
+ """
+ Computes differential attention using Flash Attention 2.
+
+ Args:
+ hidden_states: Input tensor containing sequence to attend to.
+ attention_mask: Mask to avoid attention on padding tokens.
+ position_ids: Indices of positions for positional embeddings.
+ past_key_value: Cached key and value tensors for autoregressive decoding.
+ output_attentions: Whether to return attention weights.
+ use_cache: Whether to use cached key/value states.
+ cache_position: Position indices for cached states.
+ position_embeddings: Pre-computed positional embeddings.
+ **kwargs: Additional arguments passed to the forward call.
+
+ Returns:
+ tuple containing:
+ - Output tensor after attention computation.
+ - None for attention weights (Flash Attention doesn't support output_attentions).
+ - Updated key-value cache if use_cache is True, else None.
+ """
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attention does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
) + + # pylint: disable=duplicate-code return LlamaDifferentialAttention.forward( self, hidden_states, @@ -407,6 +677,11 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase): lambda_full = self._compute_lambda(q1) attn = attn1 - lambda_full * attn2 - attn = self._process_attention_output(attn, bsz, q_len) + + # Save for logging + self.attn1 = attn1 + self.attn2 = attn2 + self.lambda_full = lambda_full + return attn, None, past_key_value diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index e0f4eebc1..c8a663cb3 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -1,7 +1,11 @@ -"""Modeling for differential transformers.""" +""" +Modeling for differential transformers. + +This module implements differential attention variants of the LLaMA model, +providing various attention implementations for improved performance. +""" import logging -from typing import Optional, Union import torch from transformers import AutoConfig, AutoModel, AutoModelForCausalLM @@ -18,7 +22,12 @@ logger = logging.getLogger(__name__) class LlamaDifferentialConfig(LlamaConfig): - """Configuration class for Differential LLaMA model.""" + """ + Configuration class for Differential LLaMA model. + + Extends the base LLaMA configuration with additional parameters for differential + attention mechanisms. + """ model_type = "llama-differential" @@ -29,6 +38,15 @@ class LlamaDifferentialConfig(LlamaConfig): zero_init: bool = False, **kwargs, ): + """ + Initialize differential LLaMA configuration. + + Args: + split_heads: Whether to use split heads mode for attention computation. + sublayer_norm: Whether to apply normalization to sublayers. + zero_init: Whether to initialize new weights to zero. + **kwargs: Additional arguments passed to LlamaConfig. + """ super().__init__(**kwargs) self.split_heads = split_heads self.sublayer_norm = sublayer_norm @@ -42,12 +60,26 @@ class LlamaDifferentialConfig(LlamaConfig): class LlamaDifferentialModel(LlamaModel): - """LlamaModel with differential attention.""" + """ + LlamaModel with differential attention. + + This class extends the base LLaMA model by replacing standard attention with + differential attention mechanisms. + """ config_class = LlamaDifferentialConfig base_model_prefix = "llama_differential" - def __init__(self, config): + def __init__(self, config: LlamaDifferentialConfig): + """ + Initialize a differential LLaMA model. + + Args: + config: Configuration object for the model. + + Raises: + ValueError: If specified attention implementation is not supported. + """ super().__init__(config) # Handle attention implementation @@ -76,11 +108,26 @@ class LlamaDifferentialModel(LlamaModel): for idx, layer in enumerate(self.layers): layer.self_attn = attn_class(config, idx) - # pylint: disable=protected-access @classmethod + # pylint: disable=protected-access def _autoset_attn_implementation( - cls, config, **kwargs - ): # pylint: disable=unused-argument + cls, + config: LlamaDifferentialConfig, + **kwargs, # pylint: disable=unused-argument + ) -> LlamaDifferentialConfig: + """ + Automatically set the attention implementation based on config. + + Args: + config: Model configuration object. + **kwargs: Additional arguments (unused). + + Returns: + Updated configuration object. + + Raises: + ValueError: If specified attention implementation is not supported. 
+ """ config._attn_implementation_autoset = True attn_implementation = getattr(config, "_attn_implementation", None) @@ -110,10 +157,23 @@ class LlamaDifferentialModel(LlamaModel): @classmethod def from_llama( cls, - model: Union[LlamaModel, LlamaForCausalLM], - config: Optional[LlamaDifferentialConfig] = None, + model: LlamaModel | LlamaForCausalLM, + config: LlamaDifferentialConfig | None = None, ) -> "LlamaDifferentialModel": - """Convert a LlamaModel to use differential attention.""" + """ + Convert a `LlamaModel` to use differential attention. + + Args: + model: Base LLaMA model to convert. + config: Configuration for differential attention. If `None`, created from + base model config. + + Returns: + Converted model with differential attention. + + Raises: + ValueError: If number of heads is not even when using `split_heads` mode. + """ logger.info(f"Converting {type(model).__name__} to {cls.__name__}") # Handle LlamaForCausalLM @@ -182,7 +242,6 @@ class LlamaDifferentialModel(LlamaModel): if config.zero_init: logger.debug(f"Layer {layer_idx}: Zero initializing") - # Zero out components as needed with torch.no_grad(): new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_() new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_() @@ -192,45 +251,60 @@ class LlamaDifferentialModel(LlamaModel): new_layer.self_attn.lambda_k2.zero_() new_layer.self_attn.lambda_init.zero_() else: - logger.debug( - f"Layer {layer_idx}: Initializing with scale {config.init_scale}" + # Mirror weights for second component + new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_( + old_layer.self_attn.q_proj.weight.data + ) + new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_( + old_layer.self_attn.k_proj.weight.data ) - # Initialize with small random values - with torch.no_grad(): - new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_( - 0, config.init_scale - ) - new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_( - 0, config.init_scale - ) - new_layer.self_attn.lambda_q1.normal_(0, config.init_scale) - new_layer.self_attn.lambda_k1.normal_(0, config.init_scale) - new_layer.self_attn.lambda_q2.normal_(0, config.init_scale) - new_layer.self_attn.lambda_k2.normal_(0, config.init_scale) - if config.reinit_lambda_init: - new_layer.self_attn.lambda_init.normal_( - 0, config.init_scale - ).abs_() logger.info("Conversion complete") + return new_model class LlamaDifferentialForCausalLM(LlamaForCausalLM): - """LlamaForCausalLM with differential attention.""" + """ + `LlamaForCausalLM` with differential attention. + + This class extends the base LLaMA causal language model by incorporating + differential attention mechanisms. + """ config_class = LlamaDifferentialConfig base_model_prefix = "llama_differential" - def __init__(self, config): + def __init__(self, config: LlamaDifferentialConfig): + """ + Initialize a differential LLaMA model for causal language modeling. + + Args: + config: Configuration object for the model. + """ super().__init__(config) self.model = LlamaDifferentialModel(config) - # pylint: disable=protected-access @classmethod + # pylint: disable=protected-access def _autoset_attn_implementation( - cls, config, **kwargs - ): # pylint: disable=unused-argument + cls, + config: LlamaDifferentialConfig, + **kwargs, # pylint: disable=unused-argument + ) -> LlamaDifferentialConfig: + """ + Automatically set the attention implementation based on config. + + Args: + config: Model configuration object. + **kwargs: Additional arguments (unused). 
+ + Returns: + Updated configuration object. + + Raises: + ValueError: If specified attention implementation is not supported. + """ config._attn_implementation_autoset = True attn_implementation = getattr(config, "_attn_implementation", None) @@ -239,6 +313,7 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM): config._attn_implementation = config._attn_implementations[ attn_implementation ] + return config # If no mapping, validate it's a valid differential type @@ -259,9 +334,22 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM): @classmethod def from_llama( - cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None + cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None ) -> "LlamaDifferentialForCausalLM": - """Convert a LlamaForCausalLM to use differential attention.""" + """ + Convert a `LlamaForCausalLM` to use differential attention. + + Args: + model: Base LLaMA model to convert. + config: Configuration for differential attention. If `None`, created from + base model config. + + Returns: + Converted model with differential attention. + + Raises: + ValueError: If number of heads is not even when using `split_heads` mode. + """ if config is None: config = LlamaDifferentialConfig(**model.config.__dict__) @@ -285,7 +373,14 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM): return new_model -def register_diff_attn(): +def register_diff_attn() -> None: + """ + Register differential attention components with the transformers library. + + This function registers the differential attention configurations and model classes + with the Auto* classes from `transformers`, making them available through the + standard model loading pipeline. + """ # Register configs AutoConfig.register("llama-differential", LlamaDifferentialConfig) diff --git a/src/axolotl/utils/callbacks/differential.py b/src/axolotl/utils/callbacks/differential.py new file mode 100644 index 000000000..49d9cd7e3 --- /dev/null +++ b/src/axolotl/utils/callbacks/differential.py @@ -0,0 +1,171 @@ +""" +Monitor and log differential attention components during training. + +This module provides a callback for tracking the behavior of differential attention +mechanisms, including lambda parameters and attention statistics. +""" + +from typing import Any + +import torch +import wandb +from transformers import TrainerCallback + +from axolotl.utils.distributed import is_main_process + + +class DifferentialAttentionMonitorCallback(TrainerCallback): + """ + Callback to monitor differential attention components and lambda parameters. + + This callback tracks attention statistics across all layers and provides detailed + monitoring for a specified number of layers evenly spaced through the model. + """ + + def __init__(self, log_every: int = 250, num_monitor_layers: int = 3): + """ + Initialize the differential attention monitor. + + Args: + log_every: Number of steps between logging events. + num_monitor_layers: Number of individual layers to monitor in detail. + """ + self.log_every = log_every + self.num_monitor_layers = num_monitor_layers + self.monitor_layers: list[int] | None = None # Will be set in on_train_begin + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module, + **kwargs, + ) -> None: + """ + Set up layer monitoring at the start of training. + + Args: + args: Training arguments. + state: Training state. + control: Training control object. + model: The model being trained. 
+ **kwargs: Additional arguments passed by the trainer. + """ + if is_main_process(): + num_layers = len(model.model.layers) + self.num_monitor_layers = min(self.num_monitor_layers, num_layers) + + stride = ( + (num_layers - 1) / (self.num_monitor_layers - 1) + if self.num_monitor_layers > 1 + else 0 + ) + self.monitor_layers = [ + round(i * stride) for i in range(self.num_monitor_layers) + ] + print(f"Monitoring layers {self.monitor_layers} in detail") + + # pylint: disable=unused-argument + def on_step_end( + self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs + ) -> None: + """ + Log attention metrics at the end of each step. + + Collects and logs: + - Lambda parameter norms and values. + - Attention statistics (mean and std). + - Both per-layer and aggregate metrics. + + Args: + args: Training arguments. + state: Training state. + control: Training control object. + model: The model being trained. + **kwargs: Additional arguments passed by the trainer. + """ + if not is_main_process() or state.global_step % self.log_every != 0: + return + + assert self.monitor_layers is not None + + # Aggregate stats across all layers + all_q1_norms = [] + all_q2_norms = [] + all_k1_norms = [] + all_k2_norms = [] + all_lambda1 = [] + all_lambda2 = [] + all_lambda_full = [] + + metrics = {} + + for layer_idx, layer in enumerate(model.model.layers): + attn = layer.self_attn + + # Collect stats for aggregation + all_q1_norms.append(attn.lambda_q1.norm().item()) + all_q2_norms.append(attn.lambda_q2.norm().item()) + all_k1_norms.append(attn.lambda_k1.norm().item()) + all_k2_norms.append(attn.lambda_k2.norm().item()) + + lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item() + lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item() + all_lambda1.append(lambda1) + all_lambda2.append(lambda2) + all_lambda_full.append(attn.lambda_full) + + # Log detailed metrics for monitored layers + if layer_idx in self.monitor_layers: + metrics.update( + { + f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(), + f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(), + f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(), + f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(), + f"layer_{layer_idx}/lambda1": lambda1, + f"layer_{layer_idx}/lambda2": lambda2, + f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(), + f"layer_{layer_idx}/lambda_full": lambda1 + - lambda2 + + attn.lambda_init.item(), + f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(), + f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(), + f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(), + f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(), + } + ) + + # Add aggregate metrics + metrics.update( + { + "aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms) + .mean() + .item(), + "aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(), + "aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms) + .mean() + .item(), + "aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(), + "aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms) + .mean() + .item(), + "aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(), + "aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms) + .mean() + .item(), + "aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(), + "aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(), + "aggregate/lambda1_std": 
torch.tensor(all_lambda1).std().item(), + "aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(), + "aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(), + "aggregate/lambda_full_mean": torch.tensor(all_lambda_full) + .mean() + .item(), + "aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(), + } + ) + + wandb.log(metrics, step=state.global_step) diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index 92c8053c0..e1ad31fdd 100644 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config): yaml.dump(base_config, file) cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1) + cli_args = ConvertDiffTransformerCliArgs(debug=True) _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert not debug_info["generations_match"]
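
For reviewers who want the mechanism at a glance, here is a minimal single-head sketch of the computation that the eager `LlamaDifferentialAttention.forward`, `_compute_lambda`, and `_process_attention_output` in this patch perform. It is an illustration, not part of the patch: RoPE, KV caching, masking, the sublayer RMSNorm, and the output projection are omitted, and the tensor shapes and layer index are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def differential_attention_sketch(
    q1, q2, k1, k2, v, lambda_q1, lambda_k1, lambda_q2, lambda_k2, lambda_init
):
    """Single-head sketch: difference of two softmax attention maps."""
    scale = 1.0 / math.sqrt(q1.shape[-1])
    # Two standard attention maps, one per (query, key) pair.
    attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # lambda_full = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init,
    # as in `_compute_lambda`.
    lambda_full = (
        torch.exp(torch.sum(lambda_q1 * lambda_k1))
        - torch.exp(torch.sum(lambda_q2 * lambda_k2))
        + lambda_init
    )
    # Difference of the two maps applied to the values, then the (1 - lambda_init)
    # scaling from `_process_attention_output`; sublayer RMSNorm and o_proj are skipped here.
    out = attn1 @ v - lambda_full * (attn2 @ v)
    return out * (1 - lambda_init)


seq_len, head_dim = 8, 16
q1, q2, k1, k2, v = (torch.randn(seq_len, head_dim) for _ in range(5))
lq1, lk1, lq2, lk2 = (0.1 * torch.randn(head_dim) for _ in range(4))
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * 12)  # lambda_init_fn for layer index 12
out = differential_attention_sketch(q1, q2, k1, k2, v, lq1, lk1, lq2, lk2, lambda_init)
print(out.shape)  # torch.Size([8, 16])
```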
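As a usage illustration (again not part of the diff), the conversion path in `modeling_diff_attn.py` can be exercised roughly as follows. The checkpoint name is a placeholder, and the flag values simply mirror the new `ConvertDiffTransformerCliArgs` defaults; any Llama-architecture model should work.

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
    register_diff_attn,
)

# Register the "llama-differential" config/model classes with the Auto* classes,
# as the plugin constructor does, so converted checkpoints can be re-loaded later.
register_diff_attn()

# Placeholder checkpoint name.
base = AutoModelForCausalLM.from_pretrained("my-org/my-llama-checkpoint")

config = LlamaDifferentialConfig(
    **base.config.__dict__,
    split_heads=False,   # double-projection mode
    sublayer_norm=True,  # matches the new CLI default
    zero_init=False,     # mirror the base Q/K weights into the negative component
)
diff_model = LlamaDifferentialForCausalLM.from_llama(base, config)
diff_model.to(dtype=torch.bfloat16)
```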
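Continuing from the previous sketch, this is roughly how the plugin wires in the new monitoring callback. The config values are hypothetical and only cover the keys `add_callbacks_pre_trainer` reads (`use_wandb` plus the two new `DifferentialTransformerArgs` fields); in a real run `axolotl` builds this config from the training YAML.

```python
from axolotl.integrations.diff_transformer import DifferentialTransformerPlugin
from axolotl.utils.dict import DictDefault

# Hypothetical config values; keys mirror what add_callbacks_pre_trainer reads.
cfg = DictDefault(
    {
        "use_wandb": True,
        "diff_attn_log_every": 100,
        "diff_attn_num_monitor_layers": 3,
    }
)

plugin = DifferentialTransformerPlugin()
# `diff_model` is the converted model from the previous sketch.
callbacks = plugin.add_callbacks_pre_trainer(cfg=cfg, model=diff_model)
# Expect a single DifferentialAttentionMonitorCallback configured with the values above;
# with use_wandb falsy, the list would be empty.
print(callbacks)
```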