adding diff attn callback, adding documentation
This commit is contained in:
@@ -202,7 +202,7 @@ def do_inference(
|
|||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template)
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
elif cfg.datasets[0].type == "chat_template":
|
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||||
chat_template_str = get_chat_template_from_config(
|
chat_template_str = get_chat_template_from_config(
|
||||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -87,8 +87,6 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
|||||||
zero_init=cli_args.zero_init,
|
zero_init=cli_args.zero_init,
|
||||||
sublayer_norm=cli_args.sublayer_norm,
|
sublayer_norm=cli_args.sublayer_norm,
|
||||||
split_heads=cli_args.split_heads,
|
split_heads=cli_args.split_heads,
|
||||||
init_scale=cli_args.init_scale,
|
|
||||||
reinit_lambda_init=cli_args.reinit_lambda_init,
|
|
||||||
)
|
)
|
||||||
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
||||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|||||||
@@ -54,10 +54,8 @@ class ConvertDiffTransformerCliArgs:
|
|||||||
|
|
||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
zero_init: bool = field(default=False)
|
zero_init: bool = field(default=False)
|
||||||
sublayer_norm: bool = field(default=False)
|
sublayer_norm: bool = field(default=True)
|
||||||
split_heads: bool = field(default=False)
|
split_heads: bool = field(default=False)
|
||||||
init_scale: float = field(default=1e-6)
|
|
||||||
reinit_lambda_init: bool = field(default=True)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
"""Definition of differential transformer plugin."""
|
"""Definition of differential transformer plugin."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel, TrainerCallback
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -10,10 +15,41 @@ LOG = logging.getLogger(__name__)
|
|||||||
class DifferentialTransformerPlugin(BasePlugin):
|
class DifferentialTransformerPlugin(BasePlugin):
|
||||||
"""Plugin for differential transformer integration with Axolotl."""
|
"""Plugin for differential transformer integration with Axolotl."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Constructor for differential transformers plugin. Calls `register_diff_attn`
|
||||||
|
to register differential attention custom modeling implementation to `AutoConfig`
|
||||||
|
and `AutoModel`.
|
||||||
|
"""
|
||||||
from .modeling_diff_attn import register_diff_attn
|
from .modeling_diff_attn import register_diff_attn
|
||||||
|
|
||||||
register_diff_attn()
|
register_diff_attn()
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self) -> str:
|
||||||
|
"""Returns module path to diff transformer plugin args for `axolotl` config."""
|
||||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||||
|
|
||||||
|
def add_callbacks_pre_trainer(
|
||||||
|
self, cfg: DictDefault, model: PreTrainedModel
|
||||||
|
) -> List[TrainerCallback]:
|
||||||
|
"""
|
||||||
|
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
|
||||||
|
callbacks for the `axolotl` trainer if wandb usage is enabled.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
model: The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
|
||||||
|
"""
|
||||||
|
callbacks = []
|
||||||
|
if cfg.use_wandb:
|
||||||
|
callbacks.append(
|
||||||
|
DifferentialAttentionMonitorCallback(
|
||||||
|
log_every=cfg.diff_attn_log_every,
|
||||||
|
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|||||||
@@ -12,3 +12,5 @@ class DifferentialTransformerArgs(BaseModel):
|
|||||||
"""Input args for differential transformer."""
|
"""Input args for differential transformer."""
|
||||||
|
|
||||||
diff_attention: Optional[bool] = None
|
diff_attention: Optional[bool] = None
|
||||||
|
diff_attn_log_every: Optional[int] = 100
|
||||||
|
diff_attn_num_monitor_layers: Optional[int] = 3
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Re-implementation of differential attention."""
|
"""Re-implementation of differential attention from the Differential Transformer paper
|
||||||
|
(https://arxiv.org/abs/2410.05258)."""
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -27,6 +28,18 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Repeats key/value heads to match the number of query heads in multi-head attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
|
||||||
|
n_rep: Number of times to repeat each head.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
|
||||||
|
seq_len, head_dim)`.
|
||||||
|
If `n_rep` is 1, returns the input tensor unchanged.
|
||||||
|
"""
|
||||||
batch_size, n_kv_heads, slen, head_dim = x.shape
|
batch_size, n_kv_heads, slen, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
@@ -37,14 +50,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def lambda_init_fn(depth):
|
def lambda_init_fn(depth: int) -> float:
|
||||||
|
"""
|
||||||
|
Lambda mixing parameter init function from the "Differential Transformer" paper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth: Index of layer to init lambda parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Lambda initialization value (decreasing with `depth`).
|
||||||
|
"""
|
||||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialAttentionBase(nn.Module):
|
class LlamaDifferentialAttentionBase(nn.Module):
|
||||||
"""Base class for differential attention implementations."""
|
"""
|
||||||
|
Base class for differential attention implementations.
|
||||||
|
|
||||||
def __init__(self, config: Any, layer_idx: int):
|
This class implements the core differential attention mechanism used in Llama models.
|
||||||
|
It supports both split heads and double projection modes for attention computation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Any, layer_idx: int) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the differential attention module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration object containing hyperparameters, including:
|
||||||
|
- hidden_size: The size of hidden states
|
||||||
|
- num_attention_heads: Number of attention heads
|
||||||
|
- num_key_value_heads: Number of key/value heads
|
||||||
|
- attention_bias: Whether to use bias in attention projections
|
||||||
|
- split_heads: Whether to use split heads mode
|
||||||
|
- rms_norm_eps: Epsilon for RMS normalization
|
||||||
|
layer_idx: The index of this layer in the model
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The initialization process consists of four steps:
|
||||||
|
1. Configuration initialization (`_init_config`)
|
||||||
|
2. Projection layers initialization (`_init_projections`)
|
||||||
|
3. Differential parameters initialization (`_init_differential_params`)
|
||||||
|
4. Normalization layers initialization (`_init_normalization`)
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -53,8 +100,24 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
self._init_differential_params()
|
self._init_differential_params()
|
||||||
self._init_normalization()
|
self._init_normalization()
|
||||||
|
|
||||||
def _init_config(self, layer_idx: int):
|
# For logging
|
||||||
"""Initialize configuration parameters."""
|
self.attn1 = None
|
||||||
|
self.attn2 = None
|
||||||
|
self.lambda_full = None
|
||||||
|
|
||||||
|
def _init_config(self, layer_idx: int) -> None:
|
||||||
|
"""
|
||||||
|
Initializes configuration parameters for the attention layer. Sets up various
|
||||||
|
dimension sizes and head counts based on the provided config. Handles both
|
||||||
|
split heads and double projection modes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_idx: Index of the current layer.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
In split heads mode, the number of heads is divided by 2 (rounding down),
|
||||||
|
which differs from the original implementation that required an even number.
|
||||||
|
"""
|
||||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
self.base_num_heads = self.config.num_attention_heads
|
self.base_num_heads = self.config.num_attention_heads
|
||||||
self.base_num_kv_heads = self.config.num_key_value_heads
|
self.base_num_kv_heads = self.config.num_key_value_heads
|
||||||
@@ -62,26 +125,26 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
if self.config.split_heads:
|
if self.config.split_heads:
|
||||||
# Split heads mode - single projections
|
|
||||||
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
|
||||||
# implementation, which asserts `self.base_num_heads` is even
|
|
||||||
self.heads_per_component = self.base_num_heads // 2
|
self.heads_per_component = self.base_num_heads // 2
|
||||||
self.kv_heads_per_component = self.base_num_kv_heads // 2
|
self.kv_heads_per_component = self.base_num_kv_heads // 2
|
||||||
self.value_head_dim = 2 * self.head_dim
|
self.value_head_dim = 2 * self.head_dim
|
||||||
else:
|
else:
|
||||||
# Double projection mode
|
|
||||||
self.heads_per_component = self.base_num_heads
|
self.heads_per_component = self.base_num_heads
|
||||||
self.kv_heads_per_component = self.base_num_kv_heads
|
self.kv_heads_per_component = self.base_num_kv_heads
|
||||||
self.value_head_dim = self.head_dim
|
self.value_head_dim = self.head_dim
|
||||||
|
|
||||||
def _init_projections(self):
|
def _init_projections(self) -> None:
|
||||||
"""Initialize Q, K, V projections."""
|
"""
|
||||||
|
Initializes the query, key, value, and output projection layers.
|
||||||
|
|
||||||
|
Creates linear transformations for Q, K, V projections with dimensions
|
||||||
|
depending on whether split heads or double projection mode is used.
|
||||||
|
The output projection combines the attention heads back to model dimension.
|
||||||
|
"""
|
||||||
if self.config.split_heads:
|
if self.config.split_heads:
|
||||||
# Split heads mode - single projections
|
|
||||||
q_out_dim = self.config.hidden_size
|
q_out_dim = self.config.hidden_size
|
||||||
k_out_dim = self.head_dim * self.base_num_kv_heads
|
k_out_dim = self.head_dim * self.base_num_kv_heads
|
||||||
else:
|
else:
|
||||||
# Double projection mode
|
|
||||||
q_out_dim = self.config.hidden_size * 2
|
q_out_dim = self.config.hidden_size * 2
|
||||||
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
|
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
|
||||||
|
|
||||||
@@ -102,8 +165,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
bias=self.config.attention_bias,
|
bias=self.config.attention_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_differential_params(self):
|
def _init_differential_params(self) -> None:
|
||||||
"""Initialize differential attention parameters."""
|
"""
|
||||||
|
Initializes parameters specific to differential attention.
|
||||||
|
|
||||||
|
Creates learnable parameters for the differential attention mechanism:
|
||||||
|
- Lambda parameters for queries and keys
|
||||||
|
- Initial lambda value based on layer index
|
||||||
|
- Rotary position embedding layer
|
||||||
|
"""
|
||||||
self.lambda_init = nn.Parameter(
|
self.lambda_init = nn.Parameter(
|
||||||
torch.full((), lambda_init_fn(self.layer_idx)),
|
torch.full((), lambda_init_fn(self.layer_idx)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -122,19 +192,42 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
)
|
)
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||||
|
|
||||||
def _init_normalization(self):
|
def _init_normalization(self) -> None:
|
||||||
"""Initialize normalization layers."""
|
"""
|
||||||
|
Initializes normalization layers for the attention mechanism.
|
||||||
|
|
||||||
|
Sets up either RMS normalization or identity transformation based on config.
|
||||||
|
The normalization is applied to the sublayer output if enabled.
|
||||||
|
"""
|
||||||
sublayer_norm = getattr(self.config, "sublayer_norm", True)
|
sublayer_norm = getattr(self.config, "sublayer_norm", True)
|
||||||
if sublayer_norm:
|
if sublayer_norm:
|
||||||
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
|
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
|
||||||
else:
|
else:
|
||||||
self.subln = nn.Identity()
|
self.subln = nn.Identity()
|
||||||
|
|
||||||
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
|
def _prepare_attention_inputs(
|
||||||
"""Prepare inputs for attention computation."""
|
self, hidden_states: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Prepares input tensors for attention computation.
|
||||||
|
|
||||||
|
Projects input hidden states to query, key, and value spaces, then reshapes
|
||||||
|
them for multi-head attention processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: Input tensor of shape `(batch_size, seq_len,
|
||||||
|
hidden_size)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Tuple containing:
|
||||||
|
- q1: Positive attention query component
|
||||||
|
- q2: Negative attention query component
|
||||||
|
- k1: Positive attention key component
|
||||||
|
- k2: Negative attention key component
|
||||||
|
- v: Value tensor
|
||||||
|
"""
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# Project and split
|
|
||||||
q = self.q_proj(hidden_states)
|
q = self.q_proj(hidden_states)
|
||||||
k = self.k_proj(hidden_states)
|
k = self.k_proj(hidden_states)
|
||||||
v = self.v_proj(hidden_states)
|
v = self.v_proj(hidden_states)
|
||||||
@@ -158,9 +251,41 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
return q1, q2, k1, k2, v
|
return q1, q2, k1, k2, v
|
||||||
|
|
||||||
def _apply_rotary_embeddings(
|
def _apply_rotary_embeddings(
|
||||||
self, q1, q2, k1, k2, position_ids, position_embeddings
|
self,
|
||||||
):
|
q1: torch.Tensor,
|
||||||
"""Apply rotary embeddings to queries and keys."""
|
q2: torch.Tensor,
|
||||||
|
k1: torch.Tensor,
|
||||||
|
k2: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
|
||||||
|
) -> tuple[
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Applies rotary positional embeddings to queries and keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q1: Positive attention query component.
|
||||||
|
q2: Negative attention query component.
|
||||||
|
k1: Positive attention key component.
|
||||||
|
k2: Negative attention key component.
|
||||||
|
position_ids: Token position indices.
|
||||||
|
position_embeddings: Pre-computed rotary embeddings (cos, sin).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Tuple containing:
|
||||||
|
- q1: Positive attention query with positional encoding.
|
||||||
|
- q2: Negative attention query with positional encoding.
|
||||||
|
- k1: Positive attention key with positional encoding.
|
||||||
|
- k2: Negative attention key with positional encoding.
|
||||||
|
- cos: Cosine part of rotary embeddings.
|
||||||
|
- sin: Sine part of rotary embeddings.
|
||||||
|
"""
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
@@ -177,14 +302,36 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
|
|
||||||
return q1, q2, k1, k2, cos, sin
|
return q1, q2, k1, k2, cos, sin
|
||||||
|
|
||||||
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
|
def _handle_cache(
|
||||||
"""Handle caching for autoregressive generation."""
|
self,
|
||||||
|
k1: torch.Tensor,
|
||||||
|
k2: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
past_key_value: Cache | None,
|
||||||
|
cache_kwargs: dict,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Handles key-value caching for autoregressive generation and the repetition of
|
||||||
|
key-value heads to match the number of query heads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k1: Positive attention key component.
|
||||||
|
k2: Negative attention key component.
|
||||||
|
v: Value tensor.
|
||||||
|
past_key_value: Cache object for storing previous key-value pairs.
|
||||||
|
cache_kwargs: Additional arguments for cache handling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Tuple containing:
|
||||||
|
- k1: Processed positive attention key component.
|
||||||
|
- k2: Processed negative attention key component.
|
||||||
|
- v: Processed value tensor.
|
||||||
|
"""
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
k = torch.stack([k1, k2], dim=1)
|
k = torch.stack([k1, k2], dim=1)
|
||||||
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
||||||
k1, k2 = k.unbind(dim=1)
|
k1, k2 = k.unbind(dim=1)
|
||||||
|
|
||||||
# Repeat KV heads to match number of query heads
|
|
||||||
k1 = repeat_kv(k1, self.num_key_value_groups)
|
k1 = repeat_kv(k1, self.num_key_value_groups)
|
||||||
k2 = repeat_kv(k2, self.num_key_value_groups)
|
k2 = repeat_kv(k2, self.num_key_value_groups)
|
||||||
v = repeat_kv(v, self.num_key_value_groups)
|
v = repeat_kv(v, self.num_key_value_groups)
|
||||||
@@ -193,39 +340,90 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
|
|
||||||
return k1, k2, v
|
return k1, k2, v
|
||||||
|
|
||||||
def _compute_lambda(self, q1):
|
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute lambda values for differential attention."""
|
"""
|
||||||
|
Computes lambda values for differential attention.
|
||||||
|
|
||||||
|
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
|
||||||
|
computed from the learned parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q1: Positive attention query component, used for type casting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Computed lambda value for differential attention.
|
||||||
|
"""
|
||||||
lambda_1 = torch.exp(
|
lambda_1 = torch.exp(
|
||||||
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
||||||
).type_as(q1)
|
).type_as(q1)
|
||||||
lambda_2 = torch.exp(
|
lambda_2 = torch.exp(
|
||||||
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
||||||
).type_as(q1)
|
).type_as(q1)
|
||||||
|
|
||||||
return lambda_1 - lambda_2 + self.lambda_init
|
return lambda_1 - lambda_2 + self.lambda_init
|
||||||
|
|
||||||
def _process_attention_output(self, attn, bsz, q_len):
|
def _process_attention_output(
|
||||||
"""Process and project attention output."""
|
self, attn: torch.Tensor, bsz: int, q_len: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Processes and projects the attention output. Applies sublayer normalization,
|
||||||
|
scales by (1 - λ_init), and projects back to model dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attn: Raw attention output.
|
||||||
|
bsz: Batch size.
|
||||||
|
q_len: Query sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
||||||
|
"""
|
||||||
attn = self.subln(attn)
|
attn = self.subln(attn)
|
||||||
attn = attn * (1 - self.lambda_init)
|
attn = attn * (1 - self.lambda_init)
|
||||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
||||||
|
|
||||||
return self.o_proj(attn)
|
return self.o_proj(attn)
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||||
"""Standard implementation of differential attention."""
|
"""
|
||||||
|
Standard implementation of differential attention.
|
||||||
|
|
||||||
|
This class implements the standard differential attention mechanism using
|
||||||
|
explicit matrix multiplications for the attention computation.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Cache | None = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False, # pylint: disable=unused-argument
|
use_cache: bool = False, # pylint: disable=unused-argument
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: torch.LongTensor | None = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Computes differential attention using standard matrix multiplication operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: Input tensor containing sequence to attend to.
|
||||||
|
attention_mask: Mask to avoid attention on padding tokens.
|
||||||
|
position_ids: Indices of positions for positional embeddings.
|
||||||
|
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||||
|
output_attentions: Whether to return attention weights.
|
||||||
|
use_cache: Whether to use cached key/value states.
|
||||||
|
cache_position: Position indices for cached states.
|
||||||
|
position_embeddings: Pre-computed positional embeddings.
|
||||||
|
**kwargs: Additional arguments passed to the forward call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple containing:
|
||||||
|
- Output tensor after attention computation.
|
||||||
|
- Attention weights if output_attentions is True, else None.
|
||||||
|
- Updated key-value cache if use_cache is True, else None.
|
||||||
|
"""
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
|
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
|
||||||
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
|
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
|
||||||
@@ -255,6 +453,11 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
|||||||
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
|
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
|
||||||
attn = self._process_attention_output(attn, bsz, q_len)
|
attn = self._process_attention_output(attn, bsz, q_len)
|
||||||
|
|
||||||
|
# Save for logging
|
||||||
|
self.attn1 = attn1
|
||||||
|
self.attn2 = attn2
|
||||||
|
self.lambda_full = lambda_full
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
attn_weights = attn1 - lambda_full * attn2
|
attn_weights = attn1 - lambda_full * attn2
|
||||||
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
|
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
|
||||||
@@ -263,27 +466,53 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
||||||
"""SDPA-based implementation of differential attention."""
|
"""
|
||||||
|
SDPA-based implementation of differential attention.
|
||||||
|
|
||||||
|
This class implements differential attention using PyTorch's scaled_dot_product_attention
|
||||||
|
for improved performance on supported hardware.
|
||||||
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Cache | None = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: torch.LongTensor | None = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Computes differential attention using PyTorch's scaled dot product attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: Input tensor containing sequence to attend to.
|
||||||
|
attention_mask: Mask to avoid attention on padding tokens.
|
||||||
|
position_ids: Indices of positions for positional embeddings.
|
||||||
|
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||||
|
output_attentions: Whether to return attention weights.
|
||||||
|
use_cache: Whether to use cached key/value states.
|
||||||
|
cache_position: Position indices for cached states.
|
||||||
|
position_embeddings: Pre-computed positional embeddings.
|
||||||
|
**kwargs: Additional arguments passed to the forward call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple containing:
|
||||||
|
- Output tensor after attention computation.
|
||||||
|
- None for attention weights (SDPA doesn't support output_attentions).
|
||||||
|
- Updated key-value cache if use_cache is True, else None.
|
||||||
|
"""
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
|
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
|
||||||
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
|
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||||
+ "`output_attentions=True`. Falling back to the eager attention implementation."
|
+ "`output_attentions=True`. Falling back to the eager attention implementation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
return LlamaDifferentialAttention.forward(
|
return LlamaDifferentialAttention.forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -326,15 +555,35 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
|||||||
|
|
||||||
lambda_full = self._compute_lambda(q1)
|
lambda_full = self._compute_lambda(q1)
|
||||||
attn = attn1 - lambda_full * attn2
|
attn = attn1 - lambda_full * attn2
|
||||||
|
|
||||||
attn = self._process_attention_output(attn, bsz, q_len)
|
attn = self._process_attention_output(attn, bsz, q_len)
|
||||||
|
|
||||||
|
# Save for logging
|
||||||
|
self.attn1 = attn1
|
||||||
|
self.attn2 = attn2
|
||||||
|
self.lambda_full = lambda_full
|
||||||
|
|
||||||
return attn, None, past_key_value
|
return attn, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||||
"""Flash Attention 2-based implementation of differential attention."""
|
"""
|
||||||
|
Flash Attention 2-based implementation of differential attention.
|
||||||
|
|
||||||
|
This class implements differential attention using Flash Attention 2 for maximum
|
||||||
|
performance on supported hardware.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes the Flash Attention 2 differential attention module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Positional arguments passed to parent class.
|
||||||
|
**kwargs: Keyword arguments passed to parent class.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If flash-attn library is not installed.
|
||||||
|
"""
|
||||||
if not FLASH_ATTENTION_AVAILABLE:
|
if not FLASH_ATTENTION_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
|
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
|
||||||
@@ -343,25 +592,46 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Cache | None = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: torch.LongTensor | None = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Computes differential attention using Flash Attention 2.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: Input tensor containing sequence to attend to.
|
||||||
|
attention_mask: Mask to avoid attention on padding tokens.
|
||||||
|
position_ids: Indices of positions for positional embeddings.
|
||||||
|
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||||
|
output_attentions: Whether to return attention weights.
|
||||||
|
use_cache: Whether to use cached key/value states.
|
||||||
|
cache_position: Position indices for cached states.
|
||||||
|
position_embeddings: Pre-computed positional embeddings.
|
||||||
|
**kwargs: Additional arguments passed to the forward call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple containing:
|
||||||
|
- Output tensor after attention computation.
|
||||||
|
- None for attention weights (Flash Attention doesn't support output_attentions).
|
||||||
|
- Updated key-value cache if use_cache is True, else None.
|
||||||
|
"""
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
|
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
|
||||||
+ "flash attenion does not support `output_attentions=True`. Falling back "
|
+ "flash attenion does not support `output_attentions=True`. Falling back "
|
||||||
+ "to the eager attention implementation."
|
+ "to the eager attention implementation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
return LlamaDifferentialAttention.forward(
|
return LlamaDifferentialAttention.forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -407,6 +677,11 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
|||||||
|
|
||||||
lambda_full = self._compute_lambda(q1)
|
lambda_full = self._compute_lambda(q1)
|
||||||
attn = attn1 - lambda_full * attn2
|
attn = attn1 - lambda_full * attn2
|
||||||
|
|
||||||
attn = self._process_attention_output(attn, bsz, q_len)
|
attn = self._process_attention_output(attn, bsz, q_len)
|
||||||
|
|
||||||
|
# Save for logging
|
||||||
|
self.attn1 = attn1
|
||||||
|
self.attn2 = attn2
|
||||||
|
self.lambda_full = lambda_full
|
||||||
|
|
||||||
return attn, None, past_key_value
|
return attn, None, past_key_value
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
"""Modeling for differential transformers."""
|
"""
|
||||||
|
Modeling for differential transformers.
|
||||||
|
|
||||||
|
This module implements differential attention variants of the LLaMA model,
|
||||||
|
providing various attention implementations for improved performance.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
@@ -18,7 +22,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialConfig(LlamaConfig):
|
class LlamaDifferentialConfig(LlamaConfig):
|
||||||
"""Configuration class for Differential LLaMA model."""
|
"""
|
||||||
|
Configuration class for Differential LLaMA model.
|
||||||
|
|
||||||
|
Extends the base LLaMA configuration with additional parameters for differential
|
||||||
|
attention mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
model_type = "llama-differential"
|
model_type = "llama-differential"
|
||||||
|
|
||||||
@@ -29,6 +38,15 @@ class LlamaDifferentialConfig(LlamaConfig):
|
|||||||
zero_init: bool = False,
|
zero_init: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize differential LLaMA configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
split_heads: Whether to use split heads mode for attention computation.
|
||||||
|
sublayer_norm: Whether to apply normalization to sublayers.
|
||||||
|
zero_init: Whether to initialize new weights to zero.
|
||||||
|
**kwargs: Additional arguments passed to LlamaConfig.
|
||||||
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.split_heads = split_heads
|
self.split_heads = split_heads
|
||||||
self.sublayer_norm = sublayer_norm
|
self.sublayer_norm = sublayer_norm
|
||||||
@@ -42,12 +60,26 @@ class LlamaDifferentialConfig(LlamaConfig):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialModel(LlamaModel):
|
class LlamaDifferentialModel(LlamaModel):
|
||||||
"""LlamaModel with differential attention."""
|
"""
|
||||||
|
LlamaModel with differential attention.
|
||||||
|
|
||||||
|
This class extends the base LLaMA model by replacing standard attention with
|
||||||
|
differential attention mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
config_class = LlamaDifferentialConfig
|
config_class = LlamaDifferentialConfig
|
||||||
base_model_prefix = "llama_differential"
|
base_model_prefix = "llama_differential"
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: LlamaDifferentialConfig):
|
||||||
|
"""
|
||||||
|
Initialize a differential LLaMA model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object for the model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If specified attention implementation is not supported.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Handle attention implementation
|
# Handle attention implementation
|
||||||
@@ -76,11 +108,26 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
for idx, layer in enumerate(self.layers):
|
for idx, layer in enumerate(self.layers):
|
||||||
layer.self_attn = attn_class(config, idx)
|
layer.self_attn = attn_class(config, idx)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# pylint: disable=protected-access
|
||||||
def _autoset_attn_implementation(
|
def _autoset_attn_implementation(
|
||||||
cls, config, **kwargs
|
cls,
|
||||||
): # pylint: disable=unused-argument
|
config: LlamaDifferentialConfig,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
) -> LlamaDifferentialConfig:
|
||||||
|
"""
|
||||||
|
Automatically set the attention implementation based on config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration object.
|
||||||
|
**kwargs: Additional arguments (unused).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated configuration object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If specified attention implementation is not supported.
|
||||||
|
"""
|
||||||
config._attn_implementation_autoset = True
|
config._attn_implementation_autoset = True
|
||||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||||
|
|
||||||
@@ -110,10 +157,23 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(
|
def from_llama(
|
||||||
cls,
|
cls,
|
||||||
model: Union[LlamaModel, LlamaForCausalLM],
|
model: LlamaModel | LlamaForCausalLM,
|
||||||
config: Optional[LlamaDifferentialConfig] = None,
|
config: LlamaDifferentialConfig | None = None,
|
||||||
) -> "LlamaDifferentialModel":
|
) -> "LlamaDifferentialModel":
|
||||||
"""Convert a LlamaModel to use differential attention."""
|
"""
|
||||||
|
Convert a `LlamaModel` to use differential attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Base LLaMA model to convert.
|
||||||
|
config: Configuration for differential attention. If `None`, created from
|
||||||
|
base model config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Converted model with differential attention.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If number of heads is not even when using `split_heads` mode.
|
||||||
|
"""
|
||||||
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
|
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
|
||||||
|
|
||||||
# Handle LlamaForCausalLM
|
# Handle LlamaForCausalLM
|
||||||
@@ -182,7 +242,6 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
|
|
||||||
if config.zero_init:
|
if config.zero_init:
|
||||||
logger.debug(f"Layer {layer_idx}: Zero initializing")
|
logger.debug(f"Layer {layer_idx}: Zero initializing")
|
||||||
# Zero out components as needed
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
||||||
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
|
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
|
||||||
@@ -192,45 +251,60 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_layer.self_attn.lambda_k2.zero_()
|
new_layer.self_attn.lambda_k2.zero_()
|
||||||
new_layer.self_attn.lambda_init.zero_()
|
new_layer.self_attn.lambda_init.zero_()
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
# Mirror weights for second component
|
||||||
f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
|
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
|
||||||
|
old_layer.self_attn.q_proj.weight.data
|
||||||
|
)
|
||||||
|
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
|
||||||
|
old_layer.self_attn.k_proj.weight.data
|
||||||
)
|
)
|
||||||
# Initialize with small random values
|
|
||||||
with torch.no_grad():
|
|
||||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
|
|
||||||
0, config.init_scale
|
|
||||||
)
|
|
||||||
new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
|
|
||||||
0, config.init_scale
|
|
||||||
)
|
|
||||||
new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
|
|
||||||
new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
|
|
||||||
new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
|
|
||||||
new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
|
|
||||||
if config.reinit_lambda_init:
|
|
||||||
new_layer.self_attn.lambda_init.normal_(
|
|
||||||
0, config.init_scale
|
|
||||||
).abs_()
|
|
||||||
|
|
||||||
logger.info("Conversion complete")
|
logger.info("Conversion complete")
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||||
"""LlamaForCausalLM with differential attention."""
|
"""
|
||||||
|
`LlamaForCausalLM` with differential attention.
|
||||||
|
|
||||||
|
This class extends the base LLaMA causal language model by incorporating
|
||||||
|
differential attention mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
config_class = LlamaDifferentialConfig
|
config_class = LlamaDifferentialConfig
|
||||||
base_model_prefix = "llama_differential"
|
base_model_prefix = "llama_differential"
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: LlamaDifferentialConfig):
|
||||||
|
"""
|
||||||
|
Initialize a differential LLaMA model for causal language modeling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object for the model.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = LlamaDifferentialModel(config)
|
self.model = LlamaDifferentialModel(config)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# pylint: disable=protected-access
|
||||||
def _autoset_attn_implementation(
|
def _autoset_attn_implementation(
|
||||||
cls, config, **kwargs
|
cls,
|
||||||
): # pylint: disable=unused-argument
|
config: LlamaDifferentialConfig,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
) -> LlamaDifferentialConfig:
|
||||||
|
"""
|
||||||
|
Automatically set the attention implementation based on config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration object.
|
||||||
|
**kwargs: Additional arguments (unused).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated configuration object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If specified attention implementation is not supported.
|
||||||
|
"""
|
||||||
config._attn_implementation_autoset = True
|
config._attn_implementation_autoset = True
|
||||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||||
|
|
||||||
@@ -239,6 +313,7 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
|||||||
config._attn_implementation = config._attn_implementations[
|
config._attn_implementation = config._attn_implementations[
|
||||||
attn_implementation
|
attn_implementation
|
||||||
]
|
]
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
# If no mapping, validate it's a valid differential type
|
# If no mapping, validate it's a valid differential type
|
||||||
@@ -259,9 +334,22 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(
|
def from_llama(
|
||||||
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
|
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
|
||||||
) -> "LlamaDifferentialForCausalLM":
|
) -> "LlamaDifferentialForCausalLM":
|
||||||
"""Convert a LlamaForCausalLM to use differential attention."""
|
"""
|
||||||
|
Convert a `LlamaForCausalLM` to use differential attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Base LLaMA model to convert.
|
||||||
|
config: Configuration for differential attention. If `None`, created from
|
||||||
|
base model config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Converted model with differential attention.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If number of heads is not even when using `split_heads` mode.
|
||||||
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||||
|
|
||||||
@@ -285,7 +373,14 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
|||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|
||||||
def register_diff_attn():
|
def register_diff_attn() -> None:
|
||||||
|
"""
|
||||||
|
Register differential attention components with the transformers library.
|
||||||
|
|
||||||
|
This function registers the differential attention configurations and model classes
|
||||||
|
with the Auto* classes from `transformers`, making them available through the
|
||||||
|
standard model loading pipeline.
|
||||||
|
"""
|
||||||
# Register configs
|
# Register configs
|
||||||
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
|
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
|
||||||
|
|
||||||
|
|||||||
171
src/axolotl/utils/callbacks/differential.py
Normal file
171
src/axolotl/utils/callbacks/differential.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""
|
||||||
|
Monitor and log differential attention components during training.
|
||||||
|
|
||||||
|
This module provides a callback for tracking the behavior of differential attention
|
||||||
|
mechanisms, including lambda parameters and attention statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import wandb
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialAttentionMonitorCallback(TrainerCallback):
    """
    Callback to monitor differential attention components and lambda parameters.

    Every `log_every` steps, on the main process only, this callback walks
    `model.model.layers`, reads the lambda parameters of each layer's `self_attn`
    module and logs to Weights & Biases:

    - aggregate mean / std across all layers of the lambda parameter norms and of the
      derived `lambda1`, `lambda2` and `lambda_full` values;
    - detailed per-layer metrics for `num_monitor_layers` layers evenly spaced
      through the model.

    The attention modules are expected to expose `lambda_q1`, `lambda_k1`,
    `lambda_q2`, `lambda_k2`, `lambda_init`, and the `attn1`, `attn2` and
    `lambda_full` attributes that are stored by their forward pass for logging.
    """

    def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
        """
        Initialize the differential attention monitor.

        Args:
            log_every: Number of steps between logging events.
            num_monitor_layers: Number of individual layers to monitor in detail.
        """
        self.log_every = log_every
        self.num_monitor_layers = num_monitor_layers
        # Populated in on_train_begin (or lazily on the first logging step).
        self.monitor_layers: list[int] | None = None

    def _select_monitor_layers(self, num_layers: int) -> list[int]:
        """
        Pick the indices of the layers that get detailed logging.

        Clamps `self.num_monitor_layers` to `num_layers` and spreads the monitored
        indices evenly from the first to the last layer.

        Args:
            num_layers: Total number of decoder layers in the model.

        Returns:
            Layer indices to monitor in detail.
        """
        self.num_monitor_layers = min(self.num_monitor_layers, num_layers)

        # With a single monitored layer there is no spacing to compute; the
        # stride of 0 selects layer 0 only.
        stride = (
            (num_layers - 1) / (self.num_monitor_layers - 1)
            if self.num_monitor_layers > 1
            else 0
        )
        return [round(i * stride) for i in range(self.num_monitor_layers)]

    @staticmethod
    def _mean_and_std(values: list[float]) -> tuple[float, float]:
        """
        Return the mean and sample standard deviation of `values`.

        The sample standard deviation is undefined for fewer than two values
        (`torch.std` yields NaN there), so 0.0 is reported instead to keep the
        logged series free of NaNs for single-layer models.

        Args:
            values: Per-layer scalar statistics.

        Returns:
            Tuple of `(mean, std)` as Python floats.
        """
        stats = torch.tensor(values)
        mean = stats.mean().item()
        std = stats.std().item() if stats.numel() > 1 else 0.0
        return mean, std

    # pylint: disable=unused-argument
    def on_train_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module,
        **kwargs,
    ) -> None:
        """
        Set up layer monitoring at the start of training.

        Args:
            args: Training arguments.
            state: Training state.
            control: Training control object.
            model: The model being trained.
            **kwargs: Additional arguments passed by the trainer.
        """
        if is_main_process():
            self.monitor_layers = self._select_monitor_layers(len(model.model.layers))
            print(f"Monitoring layers {self.monitor_layers} in detail")

    # pylint: disable=unused-argument
    def on_step_end(
        self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
    ) -> None:
        """
        Log attention metrics at the end of each step.

        Collects and logs:
            - Lambda parameter norms and values.
            - Attention statistics (mean and std).
            - Both per-layer and aggregate metrics.

        Nothing is logged on non-main processes or on steps that are not a multiple
        of `log_every`.

        Args:
            args: Training arguments.
            state: Training state.
            control: Training control object.
            model: The model being trained.
            **kwargs: Additional arguments passed by the trainer.
        """
        if not is_main_process() or state.global_step % self.log_every != 0:
            return

        layers = model.model.layers

        # Fall back to selecting the layers here instead of asserting, so logging
        # still works if on_train_begin was not invoked on this callback.
        if self.monitor_layers is None:
            self.monitor_layers = self._select_monitor_layers(len(layers))

        # Per-layer scalars, keyed by the metric stem used for the aggregate names.
        per_layer: dict[str, list[float]] = {
            "lambda_q1_norm": [],
            "lambda_q2_norm": [],
            "lambda_k1_norm": [],
            "lambda_k2_norm": [],
            "lambda1": [],
            "lambda2": [],
            "lambda_full": [],
        }
        metrics: dict[str, float] = {}

        # Pure monitoring: no autograd bookkeeping is needed for these reductions.
        with torch.no_grad():
            for layer_idx, layer in enumerate(layers):
                attn = layer.self_attn

                # Each .item() synchronizes with the device, so compute every
                # scalar once and reuse it for both aggregate and detailed logs.
                q1_norm = attn.lambda_q1.norm().item()
                q2_norm = attn.lambda_q2.norm().item()
                k1_norm = attn.lambda_k1.norm().item()
                k2_norm = attn.lambda_k2.norm().item()
                lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
                lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()

                # lambda_full is whatever the forward pass stored; reduce it to one
                # float so aggregation does not depend on it being single-element.
                # NOTE(review): for a single-element tensor this is the value itself.
                lambda_full = (
                    torch.as_tensor(attn.lambda_full).detach().float().mean().item()
                )

                per_layer["lambda_q1_norm"].append(q1_norm)
                per_layer["lambda_q2_norm"].append(q2_norm)
                per_layer["lambda_k1_norm"].append(k1_norm)
                per_layer["lambda_k2_norm"].append(k2_norm)
                per_layer["lambda1"].append(lambda1)
                per_layer["lambda2"].append(lambda2)
                per_layer["lambda_full"].append(lambda_full)

                # Log detailed metrics for monitored layers
                if layer_idx in self.monitor_layers:
                    lambda_init = attn.lambda_init.item()
                    prefix = f"layer_{layer_idx}"
                    metrics.update(
                        {
                            f"{prefix}/lambda_q1_norm": q1_norm,
                            f"{prefix}/lambda_k1_norm": k1_norm,
                            f"{prefix}/lambda_q2_norm": q2_norm,
                            f"{prefix}/lambda_k2_norm": k2_norm,
                            f"{prefix}/lambda1": lambda1,
                            f"{prefix}/lambda2": lambda2,
                            f"{prefix}/lambda_init": lambda_init,
                            # Recomputed from the current parameters rather than
                            # taken from the value stored by the forward pass.
                            f"{prefix}/lambda_full": lambda1 - lambda2 + lambda_init,
                            f"{prefix}/attn1_mean": attn.attn1.mean().item(),
                            f"{prefix}/attn2_mean": attn.attn2.mean().item(),
                            f"{prefix}/attn1_std": attn.attn1.std().item(),
                            f"{prefix}/attn2_std": attn.attn2.std().item(),
                        }
                    )

        # Add aggregate metrics
        for name, values in per_layer.items():
            mean, std = self._mean_and_std(values)
            metrics[f"aggregate/{name}_mean"] = mean
            metrics[f"aggregate/{name}_std"] = std

        wandb.log(metrics, step=state.global_step)
|
||||||
@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
|
|||||||
yaml.dump(base_config, file)
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
cfg = load_cfg(str(config_path))
|
||||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
|
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
assert not debug_info["generations_match"]
|
assert not debug_info["generations_match"]
|
||||||
|
|||||||
Reference in New Issue
Block a user