adding diff attn callback, adding documentation
@@ -202,7 +202,7 @@ def do_inference(
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
@@ -87,8 +87,6 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
||||
zero_init=cli_args.zero_init,
|
||||
sublayer_norm=cli_args.sublayer_norm,
|
||||
split_heads=cli_args.split_heads,
|
||||
init_scale=cli_args.init_scale,
|
||||
reinit_lambda_init=cli_args.reinit_lambda_init,
|
||||
)
|
||||
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
||||
model.to(cfg.device, dtype=cfg.torch_dtype)
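For orientation, a minimal sketch of the same conversion path outside the CLI (not part of this diff; the checkpoint name and import path are assumptions based on the plugin layout shown further down):

import torch
from transformers import LlamaForCausalLM

# Import path assumed from the diff_transformer integration layout.
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
)

# Hypothetical base checkpoint; any Llama-architecture model should work.
base = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Build a differential config from the base config, mirroring the CLI args above.
config = LlamaDifferentialConfig(
    **base.config.__dict__,
    split_heads=False,
    sublayer_norm=True,
    zero_init=False,
)

# Swap standard attention for differential attention and move to the target device.
model = LlamaDifferentialForCausalLM.from_llama(base, config)
model.to("cuda", dtype=torch.bfloat16)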
|
||||
|
||||
@@ -54,10 +54,8 @@ class ConvertDiffTransformerCliArgs:
|
||||
|
||||
debug: bool = field(default=False)
|
||||
zero_init: bool = field(default=False)
|
||||
sublayer_norm: bool = field(default=False)
|
||||
sublayer_norm: bool = field(default=True)
|
||||
split_heads: bool = field(default=False)
|
||||
init_scale: float = field(default=1e-6)
|
||||
reinit_lambda_init: bool = field(default=True)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Definition of differential transformer plugin."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from transformers import PreTrainedModel, TrainerCallback
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.callbacks.differential import DifferentialAttentionMonitorCallback
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,10 +15,41 @@ LOG = logging.getLogger(__name__)
|
||||
class DifferentialTransformerPlugin(BasePlugin):
|
||||
"""Plugin for differential transformer integration with Axolotl."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Constructor for differential transformers plugin. Calls `register_diff_attn`
|
||||
to register the custom differential attention modeling implementation with `AutoConfig`
|
||||
and `AutoModel`.
|
||||
"""
|
||||
from .modeling_diff_attn import register_diff_attn
|
||||
|
||||
register_diff_attn()
|
||||
|
||||
def get_input_args(self):
|
||||
def get_input_args(self) -> str:
|
||||
"""Returns module path to diff transformer plugin args for `axolotl` config."""
|
||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||
|
||||
def add_callbacks_pre_trainer(
|
||||
self, cfg: DictDefault, model: PreTrainedModel
|
||||
) -> List[TrainerCallback]:
|
||||
"""
|
||||
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
|
||||
callbacks for the `axolotl` trainer if wandb usage is enabled.
|
||||
|
||||
Parameters:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: The loaded model.
|
||||
|
||||
Returns:
|
||||
A list containing an instantiated `DifferentialAttentionMonitorCallback` if wandb is enabled; otherwise an empty list.
|
||||
"""
|
||||
callbacks = []
|
||||
if cfg.use_wandb:
|
||||
callbacks.append(
|
||||
DifferentialAttentionMonitorCallback(
|
||||
log_every=cfg.diff_attn_log_every,
|
||||
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
|
||||
)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -12,3 +12,5 @@ class DifferentialTransformerArgs(BaseModel):
|
||||
"""Input args for differential transformer."""
|
||||
|
||||
diff_attention: Optional[bool] = None
|
||||
diff_attn_log_every: Optional[int] = 100
|
||||
diff_attn_num_monitor_layers: Optional[int] = 3
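A hedged sketch (not part of this diff) of how these fields surface on the axolotl config consumed by the plugin; the keys mirror the field names above:

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "use_wandb": True,  # the monitor callback is only added when wandb is enabled
        "diff_attention": True,  # assumed to gate the integration
        "diff_attn_log_every": 100,  # log lambda/attention stats every 100 steps
        "diff_attn_num_monitor_layers": 3,  # track 3 evenly spaced layers in detail
    }
)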
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Re-implemention of differential attention."""
|
||||
"""Re-implemention of differential attention from the Differential Transformer paper
|
||||
(https://arxiv.org/abs/2410.05258)."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -27,6 +28,18 @@ except ImportError:
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
Repeats key/value heads to match the number of query heads in multi-head attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
|
||||
n_rep: Number of times to repeat each head.
|
||||
|
||||
Returns:
|
||||
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
|
||||
seq_len, head_dim)`.
|
||||
If `n_rep` is 1, returns the input tensor unchanged.
|
||||
"""
|
||||
batch_size, n_kv_heads, slen, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
@@ -37,14 +50,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def lambda_init_fn(depth):
|
||||
def lambda_init_fn(depth: int) -> float:
|
||||
"""
|
||||
Lambda mixing parameter init function from the "Differential Transformer" paper.
|
||||
|
||||
Args:
|
||||
depth: Index of layer to init lambda parameter.
|
||||
|
||||
Returns:
|
||||
Lambda initialization value (increasing with `depth`).
|
||||
"""
|
||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
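For intuition, a few worked values (illustrative, not part of the diff):

# lambda_init_fn(0)  = 0.8 - 0.6 * exp(0.0)  = 0.200
# lambda_init_fn(1)  = 0.8 - 0.6 * exp(-0.3) ≈ 0.356
# lambda_init_fn(11) = 0.8 - 0.6 * exp(-3.3) ≈ 0.778
# i.e. the initial mixing weight grows from 0.2 toward 0.8 in deeper layers.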
|
||||
|
||||
|
||||
class LlamaDifferentialAttentionBase(nn.Module):
|
||||
"""Base class for differential attention implementations."""
|
||||
"""
|
||||
Base class for differential attention implementations.
|
||||
|
||||
def __init__(self, config: Any, layer_idx: int):
|
||||
This class implements the core differential attention mechanism used in Llama models.
|
||||
It supports both split heads and double projection modes for attention computation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Any, layer_idx: int) -> None:
|
||||
"""
|
||||
Initializes the differential attention module.
|
||||
|
||||
Args:
|
||||
config: Model configuration object containing hyperparameters, including:
|
||||
- hidden_size: The size of hidden states
|
||||
- num_attention_heads: Number of attention heads
|
||||
- num_key_value_heads: Number of key/value heads
|
||||
- attention_bias: Whether to use bias in attention projections
|
||||
- split_heads: Whether to use split heads mode
|
||||
- rms_norm_eps: Epsilon for RMS normalization
|
||||
layer_idx: The index of this layer in the model
|
||||
|
||||
Note:
|
||||
The initialization process consists of four steps:
|
||||
1. Configuration initialization (`_init_config`)
|
||||
2. Projection layers initialization (`_init_projections`)
|
||||
3. Differential parameters initialization (`_init_differential_params`)
|
||||
4. Normalization layers initialization (`_init_normalization`)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -53,8 +100,24 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
self._init_differential_params()
|
||||
self._init_normalization()
|
||||
|
||||
def _init_config(self, layer_idx: int):
|
||||
"""Initialize configuration parameters."""
|
||||
# For logging
|
||||
self.attn1 = None
|
||||
self.attn2 = None
|
||||
self.lambda_full = None
|
||||
|
||||
def _init_config(self, layer_idx: int) -> None:
|
||||
"""
|
||||
Initializes configuration parameters for the attention layer. Sets up various
|
||||
dimension sizes and head counts based on the provided config. Handles both
|
||||
split heads and double projection modes.
|
||||
|
||||
Args:
|
||||
layer_idx: Index of the current layer.
|
||||
|
||||
Note:
|
||||
In split heads mode, the number of heads is divided by 2 (rounding down),
|
||||
which differs from the original implementation that required an even number.
|
||||
"""
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
self.base_num_heads = self.config.num_attention_heads
|
||||
self.base_num_kv_heads = self.config.num_key_value_heads
|
||||
@@ -62,26 +125,26 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if self.config.split_heads:
|
||||
# Split heads mode - single projections
|
||||
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
||||
# implementation, which asserts `self.base_num_heads` is even
|
||||
self.heads_per_component = self.base_num_heads // 2
|
||||
self.kv_heads_per_component = self.base_num_kv_heads // 2
|
||||
self.value_head_dim = 2 * self.head_dim
|
||||
else:
|
||||
# Double projection mode
|
||||
self.heads_per_component = self.base_num_heads
|
||||
self.kv_heads_per_component = self.base_num_kv_heads
|
||||
self.value_head_dim = self.head_dim
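A worked example of the head bookkeeping (illustrative, assumed hyperparameters):

# hidden_size=4096, num_attention_heads=32, num_key_value_heads=8 -> head_dim = 128
# split heads mode:       heads_per_component=16, kv_heads_per_component=4, value_head_dim=256
# double projection mode: heads_per_component=32, kv_heads_per_component=8, value_head_dim=128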
|
||||
|
||||
def _init_projections(self):
|
||||
"""Initialize Q, K, V projections."""
|
||||
def _init_projections(self) -> None:
|
||||
"""
|
||||
Initializes the query, key, value, and output projection layers.
|
||||
|
||||
Creates linear transformations for Q, K, V projections with dimensions
|
||||
depending on whether split heads or double projection mode is used.
|
||||
The output projection combines the attention heads back to model dimension.
|
||||
"""
|
||||
if self.config.split_heads:
|
||||
# Split heads mode - single projections
|
||||
q_out_dim = self.config.hidden_size
|
||||
k_out_dim = self.head_dim * self.base_num_kv_heads
|
||||
else:
|
||||
# Double projection mode
|
||||
q_out_dim = self.config.hidden_size * 2
|
||||
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
|
||||
|
||||
@@ -102,8 +165,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
bias=self.config.attention_bias,
|
||||
)
|
||||
|
||||
def _init_differential_params(self):
|
||||
"""Initialize differential attention parameters."""
|
||||
def _init_differential_params(self) -> None:
|
||||
"""
|
||||
Initializes parameters specific to differential attention.
|
||||
|
||||
Creates learnable parameters for the differential attention mechanism:
|
||||
- Lambda parameters for queries and keys
|
||||
- Initial lambda value based on layer index
|
||||
- Rotary position embedding layer
|
||||
"""
|
||||
self.lambda_init = nn.Parameter(
|
||||
torch.full((), lambda_init_fn(self.layer_idx)),
|
||||
requires_grad=False,
|
||||
@@ -122,19 +192,42 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||
|
||||
def _init_normalization(self):
|
||||
"""Initialize normalization layers."""
|
||||
def _init_normalization(self) -> None:
|
||||
"""
|
||||
Initializes normalization layers for the attention mechanism.
|
||||
|
||||
Sets up either RMS normalization or identity transformation based on config.
|
||||
The normalization is applied to the sublayer output if enabled.
|
||||
"""
|
||||
sublayer_norm = getattr(self.config, "sublayer_norm", True)
|
||||
if sublayer_norm:
|
||||
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
|
||||
else:
|
||||
self.subln = nn.Identity()
|
||||
|
||||
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
|
||||
"""Prepare inputs for attention computation."""
|
||||
def _prepare_attention_inputs(
|
||||
self, hidden_states: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepares input tensors for attention computation.
|
||||
|
||||
Projects input hidden states to query, key, and value spaces, then reshapes
|
||||
them for multi-head attention processing.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor of shape `(batch_size, seq_len,
|
||||
hidden_size)`.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing:
|
||||
- q1: Positive attention query component
|
||||
- q2: Negative attention query component
|
||||
- k1: Positive attention key component
|
||||
- k2: Negative attention key component
|
||||
- v: Value tensor
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Project and split
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
@@ -158,9 +251,41 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
return q1, q2, k1, k2, v
|
||||
|
||||
def _apply_rotary_embeddings(
|
||||
self, q1, q2, k1, k2, position_ids, position_embeddings
|
||||
):
|
||||
"""Apply rotary embeddings to queries and keys."""
|
||||
self,
|
||||
q1: torch.Tensor,
|
||||
q2: torch.Tensor,
|
||||
k1: torch.Tensor,
|
||||
k2: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
Applies rotary positional embeddings to queries and keys.
|
||||
|
||||
Args:
|
||||
q1: Positive attention query component.
|
||||
q2: Negative attention query component.
|
||||
k1: Positive attention key component.
|
||||
k2: Negative attention key component.
|
||||
position_ids: Token position indices.
|
||||
position_embeddings: Pre-computed rotary embeddings (cos, sin).
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing:
|
||||
- q1: Positive attention query with positional encoding.
|
||||
- q2: Negative attention query with positional encoding.
|
||||
- k1: Positive attention key with positional encoding.
|
||||
- k2: Negative attention key with positional encoding.
|
||||
- cos: Cosine part of rotary embeddings.
|
||||
- sin: Sine part of rotary embeddings.
|
||||
"""
|
||||
if position_embeddings is None:
|
||||
LOG.warning(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
@@ -177,14 +302,36 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
|
||||
return q1, q2, k1, k2, cos, sin
|
||||
|
||||
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
|
||||
"""Handle caching for autoregressive generation."""
|
||||
def _handle_cache(
|
||||
self,
|
||||
k1: torch.Tensor,
|
||||
k2: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
past_key_value: Cache | None,
|
||||
cache_kwargs: dict,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Handles key-value caching for autoregressive generation and the repetition of
|
||||
key-value heads to match the number of query heads.
|
||||
|
||||
Args:
|
||||
k1: Positive attention key component.
|
||||
k2: Negative attention key component.
|
||||
v: Value tensor.
|
||||
past_key_value: Cache object for storing previous key-value pairs.
|
||||
cache_kwargs: Additional arguments for cache handling.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing:
|
||||
- k1: Processed positive attention key component.
|
||||
- k2: Processed negative attention key component.
|
||||
- v: Processed value tensor.
|
||||
"""
|
||||
if past_key_value is not None:
|
||||
k = torch.stack([k1, k2], dim=1)
|
||||
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
||||
k1, k2 = k.unbind(dim=1)
|
||||
|
||||
# Repeat KV heads to match number of query heads
|
||||
k1 = repeat_kv(k1, self.num_key_value_groups)
|
||||
k2 = repeat_kv(k2, self.num_key_value_groups)
|
||||
v = repeat_kv(v, self.num_key_value_groups)
|
||||
@@ -193,39 +340,90 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
|
||||
return k1, k2, v
|
||||
|
||||
def _compute_lambda(self, q1):
|
||||
"""Compute lambda values for differential attention."""
|
||||
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes lambda values for differential attention.
|
||||
|
||||
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are
|
||||
computed from the learned parameters.
|
||||
|
||||
Args:
|
||||
q1: Positive attention query component, used for type casting.
|
||||
|
||||
Returns:
|
||||
Computed lambda value for differential attention.
|
||||
"""
|
||||
lambda_1 = torch.exp(
|
||||
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
|
||||
).type_as(q1)
|
||||
lambda_2 = torch.exp(
|
||||
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
|
||||
).type_as(q1)
|
||||
|
||||
return lambda_1 - lambda_2 + self.lambda_init
|
||||
|
||||
def _process_attention_output(self, attn, bsz, q_len):
|
||||
"""Process and project attention output."""
|
||||
def _process_attention_output(
|
||||
self, attn: torch.Tensor, bsz: int, q_len: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Processes and projects the attention output. Applies sublayer normalization,
|
||||
scales by (1 - λ_init), and projects back to model dimension.
|
||||
|
||||
Args:
|
||||
attn: Raw attention output.
|
||||
bsz: Batch size.
|
||||
q_len: Query sequence length.
|
||||
|
||||
Returns:
|
||||
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
||||
"""
|
||||
attn = self.subln(attn)
|
||||
attn = attn * (1 - self.lambda_init)
|
||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
||||
|
||||
return self.o_proj(attn)
|
||||
|
||||
|
||||
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||
"""Standard implementation of differential attention."""
|
||||
"""
|
||||
Standard implementation of differential attention.
|
||||
|
||||
This class implements the standard differential attention mechanism using
|
||||
explicit matrix multiplications for the attention computation.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False, # pylint: disable=unused-argument
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
Computes differential attention using standard matrix multiplication operations.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor containing sequence to attend to.
|
||||
attention_mask: Mask to avoid attention on padding tokens.
|
||||
position_ids: Indices of positions for positional embeddings.
|
||||
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||
output_attentions: Whether to return attention weights.
|
||||
use_cache: Whether to use cached key/value states.
|
||||
cache_position: Position indices for cached states.
|
||||
position_embeddings: Pre-computed positional embeddings.
|
||||
**kwargs: Additional arguments passed to the forward call.
|
||||
|
||||
Returns:
|
||||
tuple containing:
|
||||
- Output tensor after attention computation.
|
||||
- Attention weights if output_attentions is True, else None.
|
||||
- Updated key-value cache if use_cache is True, else None.
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
|
||||
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
|
||||
@@ -255,6 +453,11 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
|
||||
attn = self._process_attention_output(attn, bsz, q_len)
|
||||
|
||||
# Save for logging
|
||||
self.attn1 = attn1
|
||||
self.attn2 = attn2
|
||||
self.lambda_full = lambda_full
|
||||
|
||||
if output_attentions:
|
||||
attn_weights = attn1 - lambda_full * attn2
|
||||
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
|
||||
@@ -263,27 +466,53 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||
|
||||
|
||||
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
||||
"""SDPA-based implementation of differential attention."""
|
||||
"""
|
||||
SDPA-based implementation of differential attention.
|
||||
|
||||
This class implements differential attention using PyTorch's scaled_dot_product_attention
|
||||
for improved performance on supported hardware.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
Computes differential attention using PyTorch's scaled dot product attention.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor containing sequence to attend to.
|
||||
attention_mask: Mask to avoid attention on padding tokens.
|
||||
position_ids: Indices of positions for positional embeddings.
|
||||
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||
output_attentions: Whether to return attention weights.
|
||||
use_cache: Whether to use cached key/value states.
|
||||
cache_position: Position indices for cached states.
|
||||
position_embeddings: Pre-computed positional embeddings.
|
||||
**kwargs: Additional arguments passed to the forward call.
|
||||
|
||||
Returns:
|
||||
tuple containing:
|
||||
- Output tensor after attention computation.
|
||||
- None for attention weights (SDPA doesn't support output_attentions).
|
||||
- Updated key-value cache if use_cache is True, else None.
|
||||
"""
|
||||
if output_attentions:
|
||||
LOG.warning(
|
||||
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
|
||||
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
+ "`output_attentions=True`. Falling back to the eager attention implementation."
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
return LlamaDifferentialAttention.forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -326,15 +555,35 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
||||
|
||||
lambda_full = self._compute_lambda(q1)
|
||||
attn = attn1 - lambda_full * attn2
|
||||
|
||||
attn = self._process_attention_output(attn, bsz, q_len)
|
||||
|
||||
# Save for logging
|
||||
self.attn1 = attn1
|
||||
self.attn2 = attn2
|
||||
self.lambda_full = lambda_full
|
||||
|
||||
return attn, None, past_key_value
|
||||
|
||||
|
||||
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||
"""Flash Attention 2-based implementation of differential attention."""
|
||||
"""
|
||||
Flash Attention 2-based implementation of differential attention.
|
||||
|
||||
This class implements differential attention using Flash Attention 2 for maximum
|
||||
performance on supported hardware.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Initializes the Flash Attention 2 differential attention module.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments passed to parent class.
|
||||
**kwargs: Keyword arguments passed to parent class.
|
||||
|
||||
Raises:
|
||||
ImportError: If flash-attn library is not installed.
|
||||
"""
|
||||
if not FLASH_ATTENTION_AVAILABLE:
|
||||
raise ImportError(
|
||||
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
|
||||
@@ -343,25 +592,46 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
Computes differential attention using Flash Attention 2.
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor containing sequence to attend to.
|
||||
attention_mask: Mask to avoid attention on padding tokens.
|
||||
position_ids: Indices of positions for positional embeddings.
|
||||
past_key_value: Cached key and value tensors for autoregressive decoding.
|
||||
output_attentions: Whether to return attention weights.
|
||||
use_cache: Whether to use cached key/value states.
|
||||
cache_position: Position indices for cached states.
|
||||
position_embeddings: Pre-computed positional embeddings.
|
||||
**kwargs: Additional arguments passed to the forward call.
|
||||
|
||||
Returns:
|
||||
tuple containing:
|
||||
- Output tensor after attention computation.
|
||||
- None for attention weights (Flash Attention doesn't support output_attentions).
|
||||
- Updated key-value cache if use_cache is True, else None.
|
||||
"""
|
||||
if output_attentions:
|
||||
LOG.warning(
|
||||
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
|
||||
+ "flash attenion does not support `output_attentions=True`. Falling back "
|
||||
+ "to the eager attention implementation."
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
return LlamaDifferentialAttention.forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -407,6 +677,11 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||
|
||||
lambda_full = self._compute_lambda(q1)
|
||||
attn = attn1 - lambda_full * attn2
|
||||
|
||||
attn = self._process_attention_output(attn, bsz, q_len)
|
||||
|
||||
# Save for logging
|
||||
self.attn1 = attn1
|
||||
self.attn2 = attn2
|
||||
self.lambda_full = lambda_full
|
||||
|
||||
return attn, None, past_key_value
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Modeling for differential transformers."""
|
||||
"""
|
||||
Modeling for differential transformers.
|
||||
|
||||
This module implements differential attention variants of the LLaMA model,
|
||||
providing various attention implementations for improved performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
@@ -18,7 +22,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaDifferentialConfig(LlamaConfig):
|
||||
"""Configuration class for Differential LLaMA model."""
|
||||
"""
|
||||
Configuration class for Differential LLaMA model.
|
||||
|
||||
Extends the base LLaMA configuration with additional parameters for differential
|
||||
attention mechanisms.
|
||||
"""
|
||||
|
||||
model_type = "llama-differential"
|
||||
|
||||
@@ -29,6 +38,15 @@ class LlamaDifferentialConfig(LlamaConfig):
|
||||
zero_init: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize differential LLaMA configuration.
|
||||
|
||||
Args:
|
||||
split_heads: Whether to use split heads mode for attention computation.
|
||||
sublayer_norm: Whether to apply normalization to sublayers.
|
||||
zero_init: Whether to initialize new weights to zero.
|
||||
**kwargs: Additional arguments passed to LlamaConfig.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.split_heads = split_heads
|
||||
self.sublayer_norm = sublayer_norm
|
||||
@@ -42,12 +60,26 @@ class LlamaDifferentialConfig(LlamaConfig):
|
||||
|
||||
|
||||
class LlamaDifferentialModel(LlamaModel):
|
||||
"""LlamaModel with differential attention."""
|
||||
"""
|
||||
LlamaModel with differential attention.
|
||||
|
||||
This class extends the base LLaMA model by replacing standard attention with
|
||||
differential attention mechanisms.
|
||||
"""
|
||||
|
||||
config_class = LlamaDifferentialConfig
|
||||
base_model_prefix = "llama_differential"
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: LlamaDifferentialConfig):
|
||||
"""
|
||||
Initialize a differential LLaMA model.
|
||||
|
||||
Args:
|
||||
config: Configuration object for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If specified attention implementation is not supported.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Handle attention implementation
|
||||
@@ -76,11 +108,26 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
for idx, layer in enumerate(self.layers):
|
||||
layer.self_attn = attn_class(config, idx)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@classmethod
|
||||
# pylint: disable=protected-access
|
||||
def _autoset_attn_implementation(
|
||||
cls, config, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
cls,
|
||||
config: LlamaDifferentialConfig,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> LlamaDifferentialConfig:
|
||||
"""
|
||||
Automatically set the attention implementation based on config.
|
||||
|
||||
Args:
|
||||
config: Model configuration object.
|
||||
**kwargs: Additional arguments (unused).
|
||||
|
||||
Returns:
|
||||
Updated configuration object.
|
||||
|
||||
Raises:
|
||||
ValueError: If specified attention implementation is not supported.
|
||||
"""
|
||||
config._attn_implementation_autoset = True
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
@@ -110,10 +157,23 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls,
|
||||
model: Union[LlamaModel, LlamaForCausalLM],
|
||||
config: Optional[LlamaDifferentialConfig] = None,
|
||||
model: LlamaModel | LlamaForCausalLM,
|
||||
config: LlamaDifferentialConfig | None = None,
|
||||
) -> "LlamaDifferentialModel":
|
||||
"""Convert a LlamaModel to use differential attention."""
|
||||
"""
|
||||
Convert a `LlamaModel` to use differential attention.
|
||||
|
||||
Args:
|
||||
model: Base LLaMA model to convert.
|
||||
config: Configuration for differential attention. If `None`, created from
|
||||
base model config.
|
||||
|
||||
Returns:
|
||||
Converted model with differential attention.
|
||||
|
||||
Raises:
|
||||
ValueError: If number of heads is not even when using `split_heads` mode.
|
||||
"""
|
||||
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
|
||||
|
||||
# Handle LlamaForCausalLM
|
||||
@@ -182,7 +242,6 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
|
||||
if config.zero_init:
|
||||
logger.debug(f"Layer {layer_idx}: Zero initializing")
|
||||
# Zero out components as needed
|
||||
with torch.no_grad():
|
||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
||||
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
|
||||
@@ -192,45 +251,60 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
new_layer.self_attn.lambda_k2.zero_()
|
||||
new_layer.self_attn.lambda_init.zero_()
|
||||
else:
|
||||
logger.debug(
|
||||
f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
|
||||
# Mirror weights for second component
|
||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
|
||||
old_layer.self_attn.q_proj.weight.data
|
||||
)
|
||||
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
|
||||
old_layer.self_attn.k_proj.weight.data
|
||||
)
|
||||
# Initialize with small random values
|
||||
with torch.no_grad():
|
||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
|
||||
0, config.init_scale
|
||||
)
|
||||
new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
|
||||
0, config.init_scale
|
||||
)
|
||||
new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
|
||||
new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
|
||||
new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
|
||||
new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
|
||||
if config.reinit_lambda_init:
|
||||
new_layer.self_attn.lambda_init.normal_(
|
||||
0, config.init_scale
|
||||
).abs_()
|
||||
|
||||
logger.info("Conversion complete")
|
||||
|
||||
return new_model
|
||||
|
||||
|
||||
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
"""LlamaForCausalLM with differential attention."""
|
||||
"""
|
||||
`LlamaForCausalLM` with differential attention.
|
||||
|
||||
This class extends the base LLaMA causal language model by incorporating
|
||||
differential attention mechanisms.
|
||||
"""
|
||||
|
||||
config_class = LlamaDifferentialConfig
|
||||
base_model_prefix = "llama_differential"
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: LlamaDifferentialConfig):
|
||||
"""
|
||||
Initialize a differential LLaMA model for causal language modeling.
|
||||
|
||||
Args:
|
||||
config: Configuration object for the model.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.model = LlamaDifferentialModel(config)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@classmethod
|
||||
# pylint: disable=protected-access
|
||||
def _autoset_attn_implementation(
|
||||
cls, config, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
cls,
|
||||
config: LlamaDifferentialConfig,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> LlamaDifferentialConfig:
|
||||
"""
|
||||
Automatically set the attention implementation based on config.
|
||||
|
||||
Args:
|
||||
config: Model configuration object.
|
||||
**kwargs: Additional arguments (unused).
|
||||
|
||||
Returns:
|
||||
Updated configuration object.
|
||||
|
||||
Raises:
|
||||
ValueError: If specified attention implementation is not supported.
|
||||
"""
|
||||
config._attn_implementation_autoset = True
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
@@ -239,6 +313,7 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
config._attn_implementation = config._attn_implementations[
|
||||
attn_implementation
|
||||
]
|
||||
|
||||
return config
|
||||
|
||||
# If no mapping, validate it's a valid differential type
|
||||
@@ -259,9 +334,22 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
|
||||
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
|
||||
) -> "LlamaDifferentialForCausalLM":
|
||||
"""Convert a LlamaForCausalLM to use differential attention."""
|
||||
"""
|
||||
Convert a `LlamaForCausalLM` to use differential attention.
|
||||
|
||||
Args:
|
||||
model: Base LLaMA model to convert.
|
||||
config: Configuration for differential attention. If `None`, created from
|
||||
base model config.
|
||||
|
||||
Returns:
|
||||
Converted model with differential attention.
|
||||
|
||||
Raises:
|
||||
ValueError: If number of heads is not even when using `split_heads` mode.
|
||||
"""
|
||||
if config is None:
|
||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||
|
||||
@@ -285,7 +373,14 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
return new_model
|
||||
|
||||
|
||||
def register_diff_attn():
|
||||
def register_diff_attn() -> None:
|
||||
"""
|
||||
Register differential attention components with the transformers library.
|
||||
|
||||
This function registers the differential attention configurations and model classes
|
||||
with the Auto* classes from `transformers`, making them available through the
|
||||
standard model loading pipeline.
|
||||
"""
|
||||
# Register configs
|
||||
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
|
||||
|
||||
|
||||
171
src/axolotl/utils/callbacks/differential.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Monitor and log differential attention components during training.
|
||||
|
||||
This module provides a callback for tracking the behavior of differential attention
|
||||
mechanisms, including lambda parameters and attention statistics.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
|
||||
class DifferentialAttentionMonitorCallback(TrainerCallback):
|
||||
"""
|
||||
Callback to monitor differential attention components and lambda parameters.
|
||||
|
||||
This callback tracks attention statistics across all layers and provides detailed
|
||||
monitoring for a specified number of layers evenly spaced through the model.
|
||||
"""
|
||||
|
||||
def __init__(self, log_every: int = 250, num_monitor_layers: int = 3):
|
||||
"""
|
||||
Initialize the differential attention monitor.
|
||||
|
||||
Args:
|
||||
log_every: Number of steps between logging events.
|
||||
num_monitor_layers: Number of individual layers to monitor in detail.
|
||||
"""
|
||||
self.log_every = log_every
|
||||
self.num_monitor_layers = num_monitor_layers
|
||||
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: Any,
|
||||
state: Any,
|
||||
control: Any,
|
||||
model: torch.nn.Module,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Set up layer monitoring at the start of training.
|
||||
|
||||
Args:
|
||||
args: Training arguments.
|
||||
state: Training state.
|
||||
control: Training control object.
|
||||
model: The model being trained.
|
||||
**kwargs: Additional arguments passed by the trainer.
|
||||
"""
|
||||
if is_main_process():
|
||||
num_layers = len(model.model.layers)
|
||||
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
|
||||
|
||||
stride = (
|
||||
(num_layers - 1) / (self.num_monitor_layers - 1)
|
||||
if self.num_monitor_layers > 1
|
||||
else 0
|
||||
)
|
||||
self.monitor_layers = [
|
||||
round(i * stride) for i in range(self.num_monitor_layers)
|
||||
]
|
||||
print(f"Monitoring layers {self.monitor_layers} in detail")
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_step_end(
|
||||
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Log attention metrics at the end of each step.
|
||||
|
||||
Collects and logs:
|
||||
- Lambda parameter norms and values.
|
||||
- Attention statistics (mean and std).
|
||||
- Both per-layer and aggregate metrics.
|
||||
|
||||
Args:
|
||||
args: Training arguments.
|
||||
state: Training state.
|
||||
control: Training control object.
|
||||
model: The model being trained.
|
||||
**kwargs: Additional arguments passed by the trainer.
|
||||
"""
|
||||
if not is_main_process() or state.global_step % self.log_every != 0:
|
||||
return
|
||||
|
||||
assert self.monitor_layers is not None
|
||||
|
||||
# Aggregate stats across all layers
|
||||
all_q1_norms = []
|
||||
all_q2_norms = []
|
||||
all_k1_norms = []
|
||||
all_k2_norms = []
|
||||
all_lambda1 = []
|
||||
all_lambda2 = []
|
||||
all_lambda_full = []
|
||||
|
||||
metrics = {}
|
||||
|
||||
for layer_idx, layer in enumerate(model.model.layers):
|
||||
attn = layer.self_attn
|
||||
|
||||
# Collect stats for aggregation
|
||||
all_q1_norms.append(attn.lambda_q1.norm().item())
|
||||
all_q2_norms.append(attn.lambda_q2.norm().item())
|
||||
all_k1_norms.append(attn.lambda_k1.norm().item())
|
||||
all_k2_norms.append(attn.lambda_k2.norm().item())
|
||||
|
||||
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
|
||||
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
|
||||
all_lambda1.append(lambda1)
|
||||
all_lambda2.append(lambda2)
|
||||
all_lambda_full.append(attn.lambda_full)
|
||||
|
||||
# Log detailed metrics for monitored layers
|
||||
if layer_idx in self.monitor_layers:
|
||||
metrics.update(
|
||||
{
|
||||
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
|
||||
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
|
||||
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
|
||||
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
|
||||
f"layer_{layer_idx}/lambda1": lambda1,
|
||||
f"layer_{layer_idx}/lambda2": lambda2,
|
||||
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
|
||||
f"layer_{layer_idx}/lambda_full": lambda1
|
||||
- lambda2
|
||||
+ attn.lambda_init.item(),
|
||||
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
|
||||
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
|
||||
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
|
||||
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
|
||||
}
|
||||
)
|
||||
|
||||
# Add aggregate metrics
|
||||
metrics.update(
|
||||
{
|
||||
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
|
||||
.mean()
|
||||
.item(),
|
||||
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
|
||||
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
|
||||
.mean()
|
||||
.item(),
|
||||
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
|
||||
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
|
||||
.mean()
|
||||
.item(),
|
||||
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
|
||||
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
|
||||
.mean()
|
||||
.item(),
|
||||
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
|
||||
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
|
||||
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
|
||||
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
|
||||
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
|
||||
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
|
||||
.mean()
|
||||
.item(),
|
||||
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
|
||||
}
|
||||
)
|
||||
|
||||
wandb.log(metrics, step=state.global_step)
|
||||
@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
cfg = load_cfg(str(config_path))
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert not debug_info["generations_match"]