adding diff attn negative component warmup (in progress)
@@ -6,7 +6,10 @@ from typing import List
 from transformers import PreTrainedModel, TrainerCallback
 
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
+from axolotl.utils.callbacks.differential import (
+    DifferentialAttentionMixingCallback,
+    DifferentialAttentionMonitorCallback,
+)
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
@@ -29,6 +32,7 @@ class DifferentialTransformerPlugin(BasePlugin):
         """Returns module path to diff transformer plugin args for `axolotl` config."""
         return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
 
+    # pylint: disable=unused-argument
     def add_callbacks_pre_trainer(
         self, cfg: DictDefault, model: PreTrainedModel
     ) -> List[TrainerCallback]:
@@ -49,6 +53,14 @@ class DifferentialTransformerPlugin(BasePlugin):
             DifferentialAttentionMonitorCallback(
                 log_every=cfg.diff_attn_log_every,
                 num_monitor_layers=cfg.diff_attn_num_monitor_layers,
+                warmup_steps=cfg.diff_attn_warmup_steps,
            )
         )
+
+        if cfg.diff_attn_warmup_steps:
+            callbacks.append(
+                DifferentialAttentionMixingCallback(
+                    warmup_steps=cfg.diff_attn_warmup_steps
+                )
+            )
 
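For reference, a minimal sketch of the config keys this method reads (the values are illustrative, not taken from the commit). When `diff_attn_warmup_steps` is unset or `None`, only the monitor callback is added and negative attention keeps full weight from step 0.

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "diff_attention": True,
        "diff_attn_log_every": 100,
        "diff_attn_num_monitor_layers": 3,
        "diff_attn_warmup_steps": 1000,  # illustrative value
    }
)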
@@ -9,8 +9,19 @@ LOG = logging.getLogger(__name__)
 
 
 class DifferentialTransformerArgs(BaseModel):
-    """Input args for differential transformer."""
+    """
+    Input args for differential transformer.
+
+    Attributes:
+        diff_attention: Whether to use differential attention layers.
+        diff_attn_log_every: How often to log differential attention statistics.
+        diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
+        diff_attn_warmup_steps: Number of steps to linearly increase negative attention
+            mixing weight from 0 to 1. If specified, will reach full mixing at this
+            step. If `None`, negative attention has full weight from the start.
+    """
 
     diff_attention: Optional[bool] = None
     diff_attn_log_every: Optional[int] = 100
     diff_attn_num_monitor_layers: Optional[int] = 3
+    diff_attn_warmup_steps: Optional[int] = None
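As a quick illustration (not part of the commit), the pydantic model can also be constructed directly; fields left unset fall back to the defaults shown above.

from axolotl.integrations.diff_transformer.args import DifferentialTransformerArgs

args = DifferentialTransformerArgs(diff_attention=True, diff_attn_warmup_steps=1000)
assert args.diff_attn_log_every == 100  # default
assert args.diff_attn_num_monitor_layers == 3  # default
assert args.diff_attn_warmup_steps == 1000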
@@ -71,19 +71,19 @@ class LlamaDifferentialAttentionBase(nn.Module):
     It supports both split heads and double projection modes for attention computation.
     """
 
-    def __init__(self, config: Any, layer_idx: int) -> None:
+    def __init__(self, config: Any, layer_idx: int):
         """
         Initializes the differential attention module.
 
         Args:
             config: Model configuration object containing hyperparameters, including:
-                - hidden_size: The size of hidden states
-                - num_attention_heads: Number of attention heads
-                - num_key_value_heads: Number of key/value heads
-                - attention_bias: Whether to use bias in attention projections
-                - split_heads: Whether to use split heads mode
-                - rms_norm_eps: Epsilon for RMS normalization
-            layer_idx: The index of this layer in the model
+                - hidden_size: The size of hidden states.
+                - num_attention_heads: Number of attention heads.
+                - num_key_value_heads: Number of key/value heads.
+                - attention_bias: Whether to use bias in attention projections.
+                - split_heads: Whether to use split heads mode.
+                - rms_norm_eps: Epsilon for RMS normalization.
+            layer_idx: The index of this layer in the model.
 
         Note:
             The initialization process consists of four steps:
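For clarity, a minimal stand-in showing the config attributes the module reads (illustrative values; real usage passes the model's own config, e.g. a Llama config with `split_heads` added to it):

from types import SimpleNamespace

config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    attention_bias=False,
    split_heads=True,
    rms_norm_eps=1e-5,
)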
@@ -111,12 +111,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
         dimension sizes and head counts based on the provided config. Handles both
         split heads and double projection modes.
 
+        In split heads mode, the number of heads is divided by 2 (rounding down), which
+        differs from the original implementation that required an even number.
+
         Args:
             layer_idx: Index of the current layer.
-
-        Note:
-            In split heads mode, the number of heads is divided by 2 (rounding down),
-            which differs from the original implementation that required an even number.
         """
         self.head_dim = self.config.hidden_size // self.config.num_attention_heads
         self.base_num_heads = self.config.num_attention_heads
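A quick check of the rounding behavior described above (head counts here are illustrative):

for num_attention_heads in (32, 7):
    print(num_attention_heads, "->", num_attention_heads // 2)  # 32 -> 16, 7 -> 3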
@@ -170,10 +169,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
         Initializes parameters specific to differential attention.
 
         Creates learnable parameters for the differential attention mechanism:
-            - Lambda parameters for queries and keys
-            - Initial lambda value based on layer index
-            - Rotary position embedding layer
+            - Mixing parameter for negative attention component warmup phase.
+            - Lambda parameters for queries and keys.
+            - Initial lambda value based on layer index.
+            - Rotary position embedding layer.
         """
+        self.diff_attn_mix = 1.0  # Default to full mixing
+
         self.lambda_init = nn.Parameter(
             torch.full((), lambda_init_fn(self.layer_idx)),
             requires_grad=False,
@@ -190,6 +192,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
         self.lambda_k2 = nn.Parameter(
             torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
         )
+
         self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
 
     def _init_normalization(self) -> None:
@@ -344,8 +347,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
         """
         Computes lambda values for differential attention.
 
-        The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
-        computed from the learned parameters.
+        The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
+        from the learned parameters. `diff_attn_mix` is multiplied through the result
+        during the negative attention component warmup phase (if applicable).
 
         Args:
             q1: Positive attention query component, used for type casting.
@@ -359,8 +363,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
         lambda_2 = torch.exp(
             torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
         ).type_as(q1)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
-        return lambda_1 - lambda_2 + self.lambda_init
+        return self.diff_attn_mix * lambda_full
 
     def _process_attention_output(
         self, attn: torch.Tensor, bsz: int, q_len: int
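A worked example of the warmup scaling (all values below are illustrative, not from the commit): `lambda_full` keeps its usual λ₁ - λ₂ + λ_init form, and only the mixing factor changes over training.

lambda_init = 0.8              # per-layer init value (illustrative)
lambda_1, lambda_2 = 1.1, 0.9  # exp(sum(lambda_q * lambda_k)) terms (illustrative)
lambda_full = lambda_1 - lambda_2 + lambda_init  # 1.0

warmup_steps = 1000
for step in (0, 500, 1000, 2000):
    mix = min(1.0, step / warmup_steps)
    print(step, mix * lambda_full)  # 0.0, 0.5, 1.0, 1.0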
@@ -378,7 +383,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
             Processed attention output of shape (batch_size, seq_len, hidden_size)
         """
         attn = self.subln(attn)
-        attn = attn * (1 - self.lambda_init)
+        attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
         attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
 
         return self.o_proj(attn)
@@ -9,6 +9,7 @@ from typing import Any
 
 import torch
 import wandb
+from torch import nn
 from transformers import TrainerCallback
 
 from axolotl.utils.distributed import is_main_process
@@ -22,16 +23,23 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
     monitoring for a specified number of layers evenly spaced through the model.
     """
 
-    def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
+    def __init__(
+        self,
+        log_every: int = 250,
+        num_monitor_layers: int = 3,
+        warmup_steps: int | None = None,
+    ):
         """
         Initialize the differential attention monitor.
 
         Args:
             log_every: Number of steps between logging events.
             num_monitor_layers: Number of individual layers to monitor in detail.
+            warmup_steps: Number of negative attention warmup steps; enables mix logging.
         """
         self.log_every = log_every
         self.num_monitor_layers = num_monitor_layers
+        self.warmup_steps = warmup_steps
         self.monitor_layers: list[int] | None = None  # Will be set in on_train_begin
 
     # pylint: disable=unused-argument
@@ -101,7 +109,6 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
         all_lambda_full = []
 
         metrics = {}
-
         for layer_idx, layer in enumerate(model.model.layers):
             attn = layer.self_attn
 
@@ -168,4 +175,60 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
            }
        )
 
+        if self.warmup_steps:
+            metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
+
         wandb.log(metrics, step=state.global_step)
+
+
+class DifferentialAttentionMixingCallback(TrainerCallback):
+    """
+    Callback to gradually increase the weight of negative attention components during
+    training.
+    """
+
+    def __init__(self, warmup_steps: int):
+        """
+        Args:
+            warmup_steps: Number of steps to linearly increase negative attention
+                weight from 0 to 1. If `None`, negative attention has full weight from
+                the start.
+        """
+        self.warmup_steps = warmup_steps
+        self.diff_attention_layers: list[nn.Module] | None = None
+
+    # pylint: disable=unused-argument
+    def on_train_begin(
+        self,
+        args: Any,
+        state: Any,
+        control: Any,
+        model: torch.nn.Module,
+        **kwargs,
+    ) -> None:
+        """Cache the differential attention layers at the start of training."""
+        if model is not None:
+            # Get the actual model if it's wrapped
+            if hasattr(model, "module"):
+                model = model.module
+
+            # Cache all differential attention layers
+            self.diff_attention_layers = [
+                module for module in model.modules() if hasattr(module, "diff_attn_mix")
+            ]
+
+    def on_step_begin(
+        self,
+        args: Any,
+        state: Any,
+        control: Any,
+        model: torch.nn.Module = None,
+        **kwargs,
+    ) -> None:
+        if self.diff_attention_layers and self.warmup_steps:
+            # Calculate mixing parameter (0 to 1)
+            mix = min(1.0, state.global_step / self.warmup_steps)
+
+            # Update cached layers
+            for layer in self.diff_attention_layers:
+                layer.diff_attn_mix = mix
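To make the mechanism concrete, here is a self-contained simulation of what the callback does (the toy module below is an assumption, not part of the commit): any module exposing a `diff_attn_mix` attribute is cached once at train begin, then updated at each step.

from torch import nn


class ToyDiffAttn(nn.Module):
    """Stand-in for a differential attention layer exposing `diff_attn_mix`."""

    def __init__(self):
        super().__init__()
        self.diff_attn_mix = 1.0


model = nn.Sequential(ToyDiffAttn(), nn.Linear(4, 4), ToyDiffAttn())
layers = [m for m in model.modules() if hasattr(m, "diff_attn_mix")]  # on_train_begin

warmup_steps = 4
for global_step in range(6):  # on_step_begin
    mix = min(1.0, global_step / warmup_steps)
    for layer in layers:
        layer.diff_attn_mix = mix

print([layer.diff_attn_mix for layer in layers])  # [1.0, 1.0] once past warmup

In the commit itself, the plugin wires this up by appending DifferentialAttentionMixingCallback(warmup_steps=cfg.diff_attn_warmup_steps) to the trainer callbacks, as shown in the first hunks above.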