From 661d71a14b8866ed80f5a91c8213d70a35cfbba6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 10 Jan 2025 21:57:31 +0000 Subject: [PATCH] adding diff attn negative component warmup (in progress) --- .../integrations/diff_transformer/__init__.py | 14 +++- .../integrations/diff_transformer/args.py | 13 +++- .../diff_transformer/diff_attn.py | 43 ++++++------ src/axolotl/utils/callbacks/differential.py | 67 ++++++++++++++++++- 4 files changed, 114 insertions(+), 23 deletions(-) diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py index e0ebec818..fdacd2b3a 100644 --- a/src/axolotl/integrations/diff_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -6,7 +6,10 @@ from typing import List from transformers import PreTrainedModel, TrainerCallback from axolotl.integrations.base import BasePlugin -from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback +from axolotl.utils.callbacks.differential import ( + DifferentialAttentionMixingCallback, + DifferentialAttentionMonitorCallback, +) from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) @@ -29,6 +32,7 @@ class DifferentialTransformerPlugin(BasePlugin): """Returns module path to diff transformer plugin args for `axolotl` config.""" return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" + # pylint: disable=unused-argument def add_callbacks_pre_trainer( self, cfg: DictDefault, model: PreTrainedModel ) -> List[TrainerCallback]: @@ -49,6 +53,14 @@ class DifferentialTransformerPlugin(BasePlugin): DifferentialAttentionMonitorCallback( log_every=cfg.diff_attn_log_every, num_monitor_layers=cfg.diff_attn_num_monitor_layers, + warmup_steps=cfg.diff_attn_warmup_steps, + ) + ) + + if cfg.diff_attn_warmup_steps: + callbacks.append( + DifferentialAttentionMixingCallback( + warmup_steps=cfg.diff_attn_warmup_steps ) ) diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py index fe3e7d977..ebd4d03a1 100644 --- a/src/axolotl/integrations/diff_transformer/args.py +++ b/src/axolotl/integrations/diff_transformer/args.py @@ -9,8 +9,19 @@ LOG = logging.getLogger(__name__) class DifferentialTransformerArgs(BaseModel): - """Input args for differential transformer.""" + """ + Input args for differential transformer. + + Attributes: + diff_attention: Whether to use differential attention layers. + diff_attn_log_every: How often to log differential attention statistics. + diff_attn_num_monitor_layers: Number of layers to monitor for attention stats. + diff_attn_warmup_steps: Number of steps to linearly increase negative attention + mixing weight from 0 to 1. If specified, will reach full mixing at this + step. If `None`, negative attention has full weight from the start. + """ diff_attention: Optional[bool] = None diff_attn_log_every: Optional[int] = 100 diff_attn_num_monitor_layers: Optional[int] = 3 + diff_attn_warmup_steps: Optional[int] = None diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index cc3e8f90c..3744b0df5 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -71,19 +71,19 @@ class LlamaDifferentialAttentionBase(nn.Module): It supports both split heads and double projection modes for attention computation. 
""" - def __init__(self, config: Any, layer_idx: int) -> None: + def __init__(self, config: Any, layer_idx: int): """ Initializes the differential attention module. Args: config: Model configuration object containing hyperparameters, including: - - hidden_size: The size of hidden states - - num_attention_heads: Number of attention heads - - num_key_value_heads: Number of key/value heads - - attention_bias: Whether to use bias in attention projections - - split_heads: Whether to use split heads mode - - rms_norm_eps: Epsilon for RMS normalization - layer_idx: The index of this layer in the model + - hidden_size: The size of hidden states. + - num_attention_heads: Number of attention heads. + - num_key_value_heads: Number of key/value heads. + - attention_bias: Whether to use bias in attention projections. + - split_heads: Whether to use split heads mode. + - rms_norm_eps: Epsilon for RMS normalization. + layer_idx: The index of this layer in the model. Note: The initialization process consists of four steps: @@ -111,12 +111,11 @@ class LlamaDifferentialAttentionBase(nn.Module): dimension sizes and head counts based on the provided config. Handles both split heads and double projection modes. + In split heads mode, the number of heads is divided by 2 (rounding down), which + differs from the original implementation that required an even number. + Args: layer_idx: Index of the current layer. - - Note: - In split heads mode, the number of heads is divided by 2 (rounding down), - which differs from the original implementation that required an even number. """ self.head_dim = self.config.hidden_size // self.config.num_attention_heads self.base_num_heads = self.config.num_attention_heads @@ -170,10 +169,13 @@ class LlamaDifferentialAttentionBase(nn.Module): Initializes parameters specific to differential attention. Creates learnable parameters for the differential attention mechanism: - - Lambda parameters for queries and keys - - Initial lambda value based on layer index - - Rotary position embedding layer + - Mixing parameter for negative attention component warmup phase. + - Lambda parameters for queries and keys. + - Initial lambda value based on layer index. + - Rotary position embedding layer. """ + self.diff_attn_mix = 0.0 # Default to full mixing + self.lambda_init = nn.Parameter( torch.full((), lambda_init_fn(self.layer_idx)), requires_grad=False, @@ -190,6 +192,7 @@ class LlamaDifferentialAttentionBase(nn.Module): self.lambda_k2 = nn.Parameter( torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def _init_normalization(self) -> None: @@ -344,8 +347,9 @@ class LlamaDifferentialAttentionBase(nn.Module): """ Computes lambda values for differential attention. - The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are - computed from the learned parameters. + The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed + from the learned parameters. `diff_attn_mix` is multiplied through the result + for negative attention component warmup phase (if applicable). Args: q1: Positive attention query component, used for type casting. 
@@ -359,8 +363,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
         lambda_2 = torch.exp(
             torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
         ).type_as(q1)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
-        return lambda_1 - lambda_2 + self.lambda_init
+        return self.diff_attn_mix * lambda_full
 
     def _process_attention_output(
         self, attn: torch.Tensor, bsz: int, q_len: int
@@ -378,7 +383,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
             Processed attention output of shape (batch_size, seq_len, hidden_size)
         """
         attn = self.subln(attn)
-        attn = attn * (1 - self.lambda_init)
+        attn = attn * (1 - self.diff_attn_mix * self.lambda_init)
         attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
 
         return self.o_proj(attn)
diff --git a/src/axolotl/utils/callbacks/differential.py b/src/axolotl/utils/callbacks/differential.py
index 49d9cd7e3..3e99e7d5f 100644
--- a/src/axolotl/utils/callbacks/differential.py
+++ b/src/axolotl/utils/callbacks/differential.py
@@ -9,6 +9,7 @@ from typing import Any
 
 import torch
 import wandb
+from torch import nn
 from transformers import TrainerCallback
 
 from axolotl.utils.distributed import is_main_process
@@ -22,16 +23,23 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
     monitoring for a specified number of layers evenly spaced through the model.
     """
 
-    def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
+    def __init__(
+        self,
+        log_every: int = 250,
+        num_monitor_layers: int = 3,
+        warmup_steps: int | None = None,
+    ):
         """
         Initialize the differential attention monitor.
 
         Args:
             log_every: Number of steps between logging events.
             num_monitor_layers: Number of individual layers to monitor in detail.
+            warmup_steps: Negative attention warmup steps; if set, the mix is logged.
         """
        self.log_every = log_every
        self.num_monitor_layers = num_monitor_layers
+        self.warmup_steps = warmup_steps
        self.monitor_layers: list[int] | None = None  # Will be set in on_train_begin
 
     # pylint: disable=unused-argument
@@ -101,7 +109,6 @@
         all_lambda_full = []
         metrics = {}
-
         for layer_idx, layer in enumerate(model.model.layers):
             attn = layer.self_attn
@@ -168,4 +175,60 @@
             }
         )
 
+        if self.warmup_steps:
+            metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
+
         wandb.log(metrics, step=state.global_step)
+
+
+class DifferentialAttentionMixingCallback(TrainerCallback):
+    """
+    Callback to gradually increase the weight of negative attention components during
+    training.
+    """
+
+    def __init__(self, warmup_steps: int):
+        """
+        Args:
+            warmup_steps: Number of steps over which the negative attention mixing
+                weight is linearly increased from 0 to 1; full mixing is reached
+                once this many steps have elapsed.
+ """ + self.warmup_steps = warmup_steps + self.diff_attention_layers: list[nn.Module] | None = None + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module, + **kwargs, + ) -> None: + """Cache the differential attention layers at the start of training.""" + if model is not None: + # Get the actual model if it's wrapped + if hasattr(model, "module"): + model = model.module + + # Cache all differential attention layers + self.diff_attention_layers = [ + module for module in model.modules() if hasattr(module, "diff_attn_mix") + ] + + def on_step_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module = None, + **kwargs, + ) -> None: + if self.diff_attention_layers and self.warmup_steps: + # Calculate mixing parameter (0 to 1) + mix = min(1.0, state.global_step / self.warmup_steps) + + # Update cached layers + for layer in self.diff_attention_layers: + layer.diff_attn_mix = mix