adding diff attn negative component warmup (in progress)
This commit is contained in:
@@ -6,7 +6,10 @@ from typing import List
|
|||||||
from transformers import PreTrainedModel, TrainerCallback
|
from transformers import PreTrainedModel, TrainerCallback
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
|
from axolotl.utils.callbacks.differential import (
|
||||||
|
DifferentialAttentionMixingCallback,
|
||||||
|
DifferentialAttentionMonitorCallback,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -29,6 +32,7 @@ class DifferentialTransformerPlugin(BasePlugin):
|
|||||||
"""Returns module path to diff transformer plugin args for `axolotl` config."""
|
"""Returns module path to diff transformer plugin args for `axolotl` config."""
|
||||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def add_callbacks_pre_trainer(
|
def add_callbacks_pre_trainer(
|
||||||
self, cfg: DictDefault, model: PreTrainedModel
|
self, cfg: DictDefault, model: PreTrainedModel
|
||||||
) -> List[TrainerCallback]:
|
) -> List[TrainerCallback]:
|
||||||
@@ -49,6 +53,14 @@ class DifferentialTransformerPlugin(BasePlugin):
|
|||||||
DifferentialAttentionMonitorCallback(
|
DifferentialAttentionMonitorCallback(
|
||||||
log_every=cfg.diff_attn_log_every,
|
log_every=cfg.diff_attn_log_every,
|
||||||
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
|
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
|
||||||
|
warmup_steps=cfg.diff_attn_warmup_steps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.diff_attn_warmup_steps:
|
||||||
|
callbacks.append(
|
||||||
|
DifferentialAttentionMixingCallback(
|
||||||
|
warmup_steps=cfg.diff_attn_warmup_steps
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,19 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DifferentialTransformerArgs(BaseModel):
|
class DifferentialTransformerArgs(BaseModel):
|
||||||
"""Input args for differential transformer."""
|
"""
|
||||||
|
Input args for differential transformer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
diff_attention: Whether to use differential attention layers.
|
||||||
|
diff_attn_log_every: How often to log differential attention statistics.
|
||||||
|
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
|
||||||
|
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
|
||||||
|
mixing weight from 0 to 1. If specified, will reach full mixing at this
|
||||||
|
step. If `None`, negative attention has full weight from the start.
|
||||||
|
"""
|
||||||
|
|
||||||
diff_attention: Optional[bool] = None
|
diff_attention: Optional[bool] = None
|
||||||
diff_attn_log_every: Optional[int] = 100
|
diff_attn_log_every: Optional[int] = 100
|
||||||
diff_attn_num_monitor_layers: Optional[int] = 3
|
diff_attn_num_monitor_layers: Optional[int] = 3
|
||||||
|
diff_attn_warmup_steps: Optional[int] = None
|
||||||
|
|||||||
@@ -71,19 +71,19 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
It supports both split heads and double projection modes for attention computation.
|
It supports both split heads and double projection modes for attention computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Any, layer_idx: int) -> None:
|
def __init__(self, config: Any, layer_idx: int):
|
||||||
"""
|
"""
|
||||||
Initializes the differential attention module.
|
Initializes the differential attention module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Model configuration object containing hyperparameters, including:
|
config: Model configuration object containing hyperparameters, including:
|
||||||
- hidden_size: The size of hidden states
|
- hidden_size: The size of hidden states.
|
||||||
- num_attention_heads: Number of attention heads
|
- num_attention_heads: Number of attention heads.
|
||||||
- num_key_value_heads: Number of key/value heads
|
- num_key_value_heads: Number of key/value heads.
|
||||||
- attention_bias: Whether to use bias in attention projections
|
- attention_bias: Whether to use bias in attention projections.
|
||||||
- split_heads: Whether to use split heads mode
|
- split_heads: Whether to use split heads mode.
|
||||||
- rms_norm_eps: Epsilon for RMS normalization
|
- rms_norm_eps: Epsilon for RMS normalization.
|
||||||
layer_idx: The index of this layer in the model
|
layer_idx: The index of this layer in the model.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The initialization process consists of four steps:
|
The initialization process consists of four steps:
|
||||||
@@ -111,12 +111,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
dimension sizes and head counts based on the provided config. Handles both
|
dimension sizes and head counts based on the provided config. Handles both
|
||||||
split heads and double projection modes.
|
split heads and double projection modes.
|
||||||
|
|
||||||
|
In split heads mode, the number of heads is divided by 2 (rounding down), which
|
||||||
|
differs from the original implementation that required an even number.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layer_idx: Index of the current layer.
|
layer_idx: Index of the current layer.
|
||||||
|
|
||||||
Note:
|
|
||||||
In split heads mode, the number of heads is divided by 2 (rounding down),
|
|
||||||
which differs from the original implementation that required an even number.
|
|
||||||
"""
|
"""
|
||||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
self.base_num_heads = self.config.num_attention_heads
|
self.base_num_heads = self.config.num_attention_heads
|
||||||
@@ -170,10 +169,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
Initializes parameters specific to differential attention.
|
Initializes parameters specific to differential attention.
|
||||||
|
|
||||||
Creates learnable parameters for the differential attention mechanism:
|
Creates learnable parameters for the differential attention mechanism:
|
||||||
- Lambda parameters for queries and keys
|
- Mixing parameter for negative attention component warmup phase.
|
||||||
- Initial lambda value based on layer index
|
- Lambda parameters for queries and keys.
|
||||||
- Rotary position embedding layer
|
- Initial lambda value based on layer index.
|
||||||
|
- Rotary position embedding layer.
|
||||||
"""
|
"""
|
||||||
|
self.diff_attn_mix = 0.0 # Default to full mixing
|
||||||
|
|
||||||
self.lambda_init = nn.Parameter(
|
self.lambda_init = nn.Parameter(
|
||||||
torch.full((), lambda_init_fn(self.layer_idx)),
|
torch.full((), lambda_init_fn(self.layer_idx)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -190,6 +192,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
self.lambda_k2 = nn.Parameter(
|
self.lambda_k2 = nn.Parameter(
|
||||||
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
|
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||||
|
|
||||||
def _init_normalization(self) -> None:
|
def _init_normalization(self) -> None:
|
||||||
@@ -344,8 +347,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Computes lambda values for differential attention.
|
Computes lambda values for differential attention.
|
||||||
|
|
||||||
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
|
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
|
||||||
computed from the learned parameters.
|
from the learned parameters. `diff_attn_mix` is multiplied through the result
|
||||||
|
for negative attention component warmup phase (if applicable).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q1: Positive attention query component, used for type casting.
|
q1: Positive attention query component, used for type casting.
|
||||||
@@ -359,8 +363,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
lambda_2 = torch.exp(
|
lambda_2 = torch.exp(
|
||||||
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
||||||
).type_as(q1)
|
).type_as(q1)
|
||||||
|
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||||
|
|
||||||
return lambda_1 - lambda_2 + self.lambda_init
|
return self.diff_attn_mix * lambda_full
|
||||||
|
|
||||||
def _process_attention_output(
|
def _process_attention_output(
|
||||||
self, attn: torch.Tensor, bsz: int, q_len: int
|
self, attn: torch.Tensor, bsz: int, q_len: int
|
||||||
@@ -378,7 +383,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
||||||
"""
|
"""
|
||||||
attn = self.subln(attn)
|
attn = self.subln(attn)
|
||||||
attn = attn * (1 - self.lambda_init)
|
attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
|
||||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
||||||
|
|
||||||
return self.o_proj(attn)
|
return self.o_proj(attn)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import wandb
|
import wandb
|
||||||
|
from torch import nn
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
@@ -22,16 +23,23 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
|
|||||||
monitoring for a specified number of layers evenly spaced through the model.
|
monitoring for a specified number of layers evenly spaced through the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
|
def __init__(
|
||||||
|
self,
|
||||||
|
log_every: int = 250,
|
||||||
|
num_monitor_layers: int = 3,
|
||||||
|
warmup_steps: int | None = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the differential attention monitor.
|
Initialize the differential attention monitor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
log_every: Number of steps between logging events.
|
log_every: Number of steps between logging events.
|
||||||
num_monitor_layers: Number of individual layers to monitor in detail.
|
num_monitor_layers: Number of individual layers to monitor in detail.
|
||||||
|
warmup_steps: Optional parameter for negative attention component warmup.
|
||||||
"""
|
"""
|
||||||
self.log_every = log_every
|
self.log_every = log_every
|
||||||
self.num_monitor_layers = num_monitor_layers
|
self.num_monitor_layers = num_monitor_layers
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
|
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@@ -101,7 +109,6 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
|
|||||||
all_lambda_full = []
|
all_lambda_full = []
|
||||||
|
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
for layer_idx, layer in enumerate(model.model.layers):
|
for layer_idx, layer in enumerate(model.model.layers):
|
||||||
attn = layer.self_attn
|
attn = layer.self_attn
|
||||||
|
|
||||||
@@ -168,4 +175,60 @@ class DifferentialAttentionMonitorCallback(TrainerCallback):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.warmup_steps:
|
||||||
|
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
|
||||||
|
|
||||||
wandb.log(metrics, step=state.global_step)
|
wandb.log(metrics, step=state.global_step)
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialAttentionMixingCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Callback to gradually increase the weight of negative attention components during
|
||||||
|
training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, warmup_steps: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
warmup_steps: Number of steps to linearly increase negative attention
|
||||||
|
weight from 0 to 1. If `None`, negative attention has full weight from
|
||||||
|
start.
|
||||||
|
"""
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.diff_attention_layers: list[nn.Module] | None = None
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: Any,
|
||||||
|
state: Any,
|
||||||
|
control: Any,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Cache the differential attention layers at the start of training."""
|
||||||
|
if model is not None:
|
||||||
|
# Get the actual model if it's wrapped
|
||||||
|
if hasattr(model, "module"):
|
||||||
|
model = model.module
|
||||||
|
|
||||||
|
# Cache all differential attention layers
|
||||||
|
self.diff_attention_layers = [
|
||||||
|
module for module in model.modules() if hasattr(module, "diff_attn_mix")
|
||||||
|
]
|
||||||
|
|
||||||
|
def on_step_begin(
|
||||||
|
self,
|
||||||
|
args: Any,
|
||||||
|
state: Any,
|
||||||
|
control: Any,
|
||||||
|
model: torch.nn.Module = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
if self.diff_attention_layers and self.warmup_steps:
|
||||||
|
# Calculate mixing parameter (0 to 1)
|
||||||
|
mix = min(1.0, state.global_step / self.warmup_steps)
|
||||||
|
|
||||||
|
# Update cached layers
|
||||||
|
for layer in self.diff_attention_layers:
|
||||||
|
layer.diff_attn_mix = mix
|
||||||
|
|||||||
Reference in New Issue
Block a user