adding diff attn negative component warmup (in progress)

This commit is contained in:
Dan Saunders
2025-01-10 21:57:31 +00:00
parent 6dd47edcb8
commit 661d71a14b
4 changed files with 114 additions and 23 deletions

View File

@@ -6,7 +6,10 @@ from typing import List
from transformers import PreTrainedModel, TrainerCallback from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback from axolotl.utils.callbacks.differential import (
DifferentialAttentionMixingCallback,
DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -29,6 +32,7 @@ class DifferentialTransformerPlugin(BasePlugin):
"""Returns module path to diff transformer plugin args for `axolotl` config.""" """Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
# pylint: disable=unused-argument
def add_callbacks_pre_trainer( def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]: ) -> List[TrainerCallback]:
@@ -49,6 +53,14 @@ class DifferentialTransformerPlugin(BasePlugin):
DifferentialAttentionMonitorCallback( DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every, log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers, num_monitor_layers=cfg.diff_attn_num_monitor_layers,
warmup_steps=cfg.diff_attn_warmup_steps,
)
)
if cfg.diff_attn_warmup_steps:
callbacks.append(
DifferentialAttentionMixingCallback(
warmup_steps=cfg.diff_attn_warmup_steps
) )
) )

View File

@@ -9,8 +9,19 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel): class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer.""" """
Input args for differential transformer.
Attributes:
diff_attention: Whether to use differential attention layers.
diff_attn_log_every: How often to log differential attention statistics.
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
mixing weight from 0 to 1. If specified, will reach full mixing at this
step. If `None`, negative attention has full weight from the start.
"""
diff_attention: Optional[bool] = None diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100 diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3 diff_attn_num_monitor_layers: Optional[int] = 3
diff_attn_warmup_steps: Optional[int] = None

View File

@@ -71,19 +71,19 @@ class LlamaDifferentialAttentionBase(nn.Module):
It supports both split heads and double projection modes for attention computation. It supports both split heads and double projection modes for attention computation.
""" """
def __init__(self, config: Any, layer_idx: int) -> None: def __init__(self, config: Any, layer_idx: int):
""" """
Initializes the differential attention module. Initializes the differential attention module.
Args: Args:
config: Model configuration object containing hyperparameters, including: config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states - hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads - num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads - num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections - attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode - split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization - rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model layer_idx: The index of this layer in the model.
Note: Note:
The initialization process consists of four steps: The initialization process consists of four steps:
@@ -111,12 +111,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
dimension sizes and head counts based on the provided config. Handles both dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes. split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args: Args:
layer_idx: Index of the current layer. layer_idx: Index of the current layer.
Note:
In split heads mode, the number of heads is divided by 2 (rounding down),
which differs from the original implementation that required an even number.
""" """
self.head_dim = self.config.hidden_size // self.config.num_attention_heads self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads self.base_num_heads = self.config.num_attention_heads
@@ -170,10 +169,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
Initializes parameters specific to differential attention. Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism: Creates learnable parameters for the differential attention mechanism:
- Lambda parameters for queries and keys - Mixing parameter for negative attention component warmup phase.
- Initial lambda value based on layer index - Lambda parameters for queries and keys.
- Rotary position embedding layer - Initial lambda value based on layer index.
- Rotary position embedding layer.
""" """
self.diff_attn_mix = 0.0 # Default to full mixing
self.lambda_init = nn.Parameter( self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)), torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False, requires_grad=False,
@@ -190,6 +192,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.lambda_k2 = nn.Parameter( self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1) torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
) )
self.rotary_emb = LlamaRotaryEmbedding(config=self.config) self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self) -> None: def _init_normalization(self) -> None:
@@ -344,8 +347,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
""" """
Computes lambda values for differential attention. Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
computed from the learned parameters. from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args: Args:
q1: Positive attention query component, used for type casting. q1: Positive attention query component, used for type casting.
@@ -359,8 +363,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
lambda_2 = torch.exp( lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1) ).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
return lambda_1 - lambda_2 + self.lambda_init return self.diff_attn_mix * lambda_full
def _process_attention_output( def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int self, attn: torch.Tensor, bsz: int, q_len: int
@@ -378,7 +383,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
Processed attention output of shape (batch_size, seq_len, hidden_size) Processed attention output of shape (batch_size, seq_len, hidden_size)
""" """
attn = self.subln(attn) attn = self.subln(attn)
attn = attn * (1 - self.lambda_init) attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn) return self.o_proj(attn)

View File

@@ -9,6 +9,7 @@ from typing import Any
import torch import torch
import wandb import wandb
from torch import nn
from transformers import TrainerCallback from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
@@ -22,16 +23,23 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
monitoring for a specified number of layers evenly spaced through the model. monitoring for a specified number of layers evenly spaced through the model.
""" """
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3): def __init__(
self,
log_every: int = 250,
num_monitor_layers: int = 3,
warmup_steps: int | None = None,
):
""" """
Initialize the differential attention monitor. Initialize the differential attention monitor.
Args: Args:
log_every: Number of steps between logging events. log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail. num_monitor_layers: Number of individual layers to monitor in detail.
warmup_steps: Optional parameter for negative attention component warmup.
""" """
self.log_every = log_every self.log_every = log_every
self.num_monitor_layers = num_monitor_layers self.num_monitor_layers = num_monitor_layers
self.warmup_steps = warmup_steps
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument # pylint: disable=unused-argument
@@ -101,7 +109,6 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
all_lambda_full = [] all_lambda_full = []
metrics = {} metrics = {}
for layer_idx, layer in enumerate(model.model.layers): for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn attn = layer.self_attn
@@ -168,4 +175,60 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
} }
) )
if self.warmup_steps:
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
wandb.log(metrics, step=state.global_step) wandb.log(metrics, step=state.global_step)
class DifferentialAttentionMixingCallback(TrainerCallback):
    """
    Trainer callback that ramps up `diff_attn_mix` on differential attention modules.

    Every module exposing a `diff_attn_mix` attribute is collected once when training
    begins. At the start of each step that attribute is overwritten with
    `min(1.0, state.global_step / warmup_steps)`.
    """

    def __init__(self, warmup_steps: int):
        """
        Args:
            warmup_steps: Number of steps over which the mixing weight grows linearly
                with `state.global_step` before being capped at 1. A falsy value
                (e.g. `0`) turns `on_step_begin` into a no-op.
        """
        # Filled in by `on_train_begin`; remains `None` until a model is provided.
        self.diff_attention_layers: list[nn.Module] | None = None
        self.warmup_steps = warmup_steps

    # pylint: disable=unused-argument
    def on_train_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module,
        **kwargs,
    ) -> None:
        """Collect and cache every submodule that carries a `diff_attn_mix` attribute."""
        if model is None:
            return

        # A wrapper exposing `.module` is replaced by the module it wraps.
        unwrapped = getattr(model, "module", model)

        found = []
        for submodule in unwrapped.modules():
            if hasattr(submodule, "diff_attn_mix"):
                found.append(submodule)
        self.diff_attention_layers = found

    def on_step_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module = None,
        **kwargs,
    ) -> None:
        """Write the current warmup fraction to all cached layers."""
        # Nothing to do without cached layers or with warmup disabled.
        if not self.diff_attention_layers or not self.warmup_steps:
            return

        # Fraction of warmup completed, capped at 1.0 once warmup is over.
        progress = min(1.0, state.global_step / self.warmup_steps)

        for attn_module in self.diff_attention_layers:
            attn_module.diff_attn_mix = progress