adding diff attn negative component warmup (in progress)

This commit is contained in:
Dan Saunders
2025-01-10 21:57:31 +00:00
parent 6dd47edcb8
commit 661d71a14b
4 changed files with 114 additions and 23 deletions

View File

@@ -6,7 +6,10 @@ from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
from axolotl.utils.callbacks.differential import (
DifferentialAttentionMixingCallback,
DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -29,6 +32,7 @@ class DifferentialTransformerPlugin(BasePlugin):
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
@@ -49,6 +53,14 @@ class DifferentialTransformerPlugin(BasePlugin):
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
warmup_steps=cfg.diff_attn_warmup_steps,
)
)
if cfg.diff_attn_warmup_steps:
callbacks.append(
DifferentialAttentionMixingCallback(
warmup_steps=cfg.diff_attn_warmup_steps
)
)

View File

@@ -9,8 +9,19 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
"""
Input args for differential transformer.
Attributes:
diff_attention: Whether to use differential attention layers.
diff_attn_log_every: How often to log differential attention statistics.
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
mixing weight from 0 to 1. If specified, will reach full mixing at this
step. If `None`, negative attention has full weight from the start.
"""
diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3
diff_attn_warmup_steps: Optional[int] = None

View File

@@ -71,19 +71,19 @@ class LlamaDifferentialAttentionBase(nn.Module):
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int) -> None:
def __init__(self, config: Any, layer_idx: int):
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states
- num_attention_heads: Number of attention heads
- num_key_value_heads: Number of key/value heads
- attention_bias: Whether to use bias in attention projections
- split_heads: Whether to use split heads mode
- rms_norm_eps: Epsilon for RMS normalization
layer_idx: The index of this layer in the model
- hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model.
Note:
The initialization process consists of four steps:
@@ -111,12 +111,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args:
layer_idx: Index of the current layer.
Note:
In split heads mode, the number of heads is divided by 2 (rounding down),
which differs from the original implementation that required an even number.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
@@ -170,10 +169,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Lambda parameters for queries and keys
- Initial lambda value based on layer index
- Rotary position embedding layer
- Mixing parameter for negative attention component warmup phase.
- Lambda parameters for queries and keys.
- Initial lambda value based on layer index.
- Rotary position embedding layer.
"""
self.diff_attn_mix = 0.0 # Default to full mixing
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
@@ -190,6 +192,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self) -> None:
@@ -344,8 +347,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
computed from the learned parameters.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args:
q1: Positive attention query component, used for type casting.
@@ -359,8 +363,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
return lambda_1 - lambda_2 + self.lambda_init
return self.diff_attn_mix * lambda_full
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
@@ -378,7 +383,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)

View File

@@ -9,6 +9,7 @@ from typing import Any
import torch
import wandb
from torch import nn
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
@@ -22,16 +23,23 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
def __init__(
self,
log_every: int = 250,
num_monitor_layers: int = 3,
warmup_steps: int | None = None,
):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
warmup_steps: Optional parameter for negative attention component warmup.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.warmup_steps = warmup_steps
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
@@ -101,7 +109,6 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
@@ -168,4 +175,60 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
}
)
if self.warmup_steps:
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
wandb.log(metrics, step=state.global_step)
class DifferentialAttentionMixingCallback(TrainerCallback):
"""
Callback to gradually increase the weight of negative attention components during
training.
"""
def __init__(self, warmup_steps: int):
"""
Args:
warmup_steps: Number of steps to linearly increase negative attention
weight from 0 to 1. If `None`, negative attention has full weight from
start.
"""
self.warmup_steps = warmup_steps
self.diff_attention_layers: list[nn.Module] | None = None
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""Cache the differential attention layers at the start of training."""
if model is not None:
# Get the actual model if it's wrapped
if hasattr(model, "module"):
model = model.module
# Cache all differential attention layers
self.diff_attention_layers = [
module for module in model.modules() if hasattr(module, "diff_attn_mix")
]
def on_step_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module = None,
**kwargs,
) -> None:
if self.diff_attention_layers and self.warmup_steps:
# Calculate mixing parameter (0 to 1)
mix = min(1.0, state.global_step / self.warmup_steps)
# Update cached layers
for layer in self.diff_attention_layers:
layer.diff_attn_mix = mix