Add YAML dumper that preserves input config format

Dan Saunders
2024-12-20 20:39:40 +00:00
parent e0adf11b76
commit 2717b97103
17 changed files with 579 additions and 707 deletions


@@ -14,7 +14,8 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
@@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
raise
def convert_differential_transformer(cfg, cli_args, config_path):
def convert_diff_transformer(cfg, cli_args, config_path):
debug_info = {}
# Load model and tokenizer
@@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path):
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
data = yaml.safe_load(file) or {}
modified_cfg = yaml.safe_load(file) or {}
data["base_model"] = cfg.output_dir
data["differential_attention"] = True
data["plugins"] = [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
]
modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]
with open(output_config_path, "w", encoding="utf-8") as file:
yaml.dump(data, file)
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
@@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_differential_transformer(cfg, cli_args, config)
convert_diff_transformer(cfg, cli_args, config)
if __name__ == "__main__":


@@ -252,11 +252,11 @@ def merge_lora(
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_differential_transformer(config: str, **kwargs):
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_differential_transformer import do_cli
from axolotl.cli.integrations.convert_diff_transformer import do_cli
do_cli(config=config, **kwargs)


@@ -57,7 +57,7 @@ class EvaluateCliArgs:
@dataclass
class ConvertDiffTransformerCliArgs:
"""
dataclass with arguments for convert-differential-transformer CLI
dataclass with arguments for convert-diff-transformer CLI
"""
debug: bool = field(default=False)


@@ -0,0 +1,10 @@
# Differential Transformer
### Usage
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
```
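The conversion itself is exposed through the `convert-diff-transformer` CLI command, backed by the `do_cli` entrypoint shown in this diff. A minimal programmatic sketch, assuming a local `config.yml` matching the snippet above:

```python
# Sketch only: drive the conversion entrypoint from Python.
# Assumes config.yml exists locally; everything else follows the CLI defaults.
from axolotl.cli.integrations.convert_diff_transformer import do_cli

do_cli(config="config.yml")
```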


@@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin):
"""
def get_input_args(self):
return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.differential_attention:
if cfg.diff_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)


@@ -11,4 +11,4 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
differential_attention: Optional[bool] = None
diff_attention: Optional[bool] = None


@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
LlamaSdpaAttention,
)
from .differential_attention import (
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,


@@ -0,0 +1,375 @@
"""Re-implemention of differential attention."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
)
def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class DifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""
def __init__(self, config: Any, layer_idx: int):
super().__init__()
self._init_config(config, layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization(config)
def _init_config(self, config: Any, layer_idx: int):
"""Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.split_heads = config.split_heads
if config.split_heads:
# Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads // 2
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even.
self.heads_per_component = self.base_num_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
# Double projection mode
self.head_dim = config.hidden_size // config.num_attention_heads
self.heads_per_component = self.base_num_heads
self.value_head_dim = self.head_dim
def _init_projections(self):
"""Initialize Q, K, V projections."""
if self.split_heads:
# Split heads mode - single projections
q_out_dim = self.hidden_size
k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
else:
# Double projection mode
q_out_dim = self.hidden_size * 2
k_out_dim = (
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
)
self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def _init_differential_params(self):
"""Initialize differential attention parameters."""
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(
self.max_position_embeddings, self.head_dim, self.rope_theta
)
def _init_normalization(self, config):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
bsz, q_len, _ = hidden_states.size()
# Project and split
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self, q1, q2, k1, k2, position_ids, position_embeddings
):
"""Apply rotary embeddings to queries and keys."""
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q1.size(-2), device=q1.device)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
"""Handle caching for autoregressive generation."""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
return k1, k2, v
def _compute_lambda(self, q1):
"""Compute lambda values for differential attention."""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
return lambda_1 - lambda_2 + self.lambda_init
def _process_attention_output(self, attn, bsz, q_len):
"""Process and project attention output."""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(DifferentialAttentionBase):
"""Standard implementation of differential attention."""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
):
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Standard attention computation
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
if output_attentions:
return attn, attn1 - lambda_full * attn2, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
"""SDPA-based implementation of differential attention."""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# SDPA-specific attention computation
causal_mask = (
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
attn1 = F.scaled_dot_product_attention(
q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
attn2 = F.scaled_dot_product_attention(
q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention."""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Flash Attention specific processing
q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.attention_dropout if self.training else 0.0
if self.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)
attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
return attn, None, past_key_value
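The three classes above share the same core computation and differ only in the attention kernel they use. A toy sketch of the differential combination, with `lambda_full` standing in for the learned `lambda_1 - lambda_2 + lambda_init`:

```python
# Toy illustration (not part of this diff): two attention maps are computed and
# the second is subtracted after scaling by lambda.
import torch
import torch.nn.functional as F

q1, q2 = torch.randn(1, 1, 4, 8), torch.randn(1, 1, 4, 8)  # (bsz, heads, seq_len, head_dim)
k1, k2 = torch.randn(1, 1, 4, 8), torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

attn1 = F.softmax(q1 @ k1.transpose(-1, -2) / 8**0.5, dim=-1)
attn2 = F.softmax(q2 @ k2.transpose(-1, -2) / 8**0.5, dim=-1)
lambda_full = 0.5  # stands in for lambda_1 - lambda_2 + lambda_init
out = attn1 @ v - lambda_full * (attn2 @ v)  # differential output, pre-norm/projection
```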


@@ -1,10 +0,0 @@
# Differential Transformer
### Usage
```yaml
plugins:
- axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
differential_attention: true
```


@@ -1,641 +0,0 @@
"""Re-implemention of differential attention."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
)
def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttention(nn.Module):
"""Differential Attention implementation as described in the Diff Transformer paper.
This implements a modified attention mechanism that computes the difference between
two attention patterns, scaled by learned lambda parameters. The mechanism helps
reduce noise in the attention weights for irrelevant / less relevant tokens.
Key components:
- Split head dimension for differential computation
- Learned lambda parameters that control attention scaling
- Sublayer normalization on the attention output
See:
- https://arxiv.org/abs/2410.05258
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
Args:
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
"""
def __init__(
self,
config: Any,
layer_idx: int,
):
super().__init__()
# Base model config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
if config.split_heads:
self.head_dim = config.hidden_size // config.num_attention_heads // 2
else:
self.head_dim = config.hidden_size // config.num_attention_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.split_heads = config.split_heads
if config.split_heads:
# Split heads mode
# assert (
# self.base_num_heads % 2 == 0
# ), "Number of heads must be even for splitting"
self.heads_per_component = self.base_num_heads // 2
# Single projections
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
)
else:
# Double projection mode
self.heads_per_component = self.base_num_heads
# Double-sized projections
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size * 2,
bias=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
)
# Single V projection
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
)
# Output projection
self.o_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias=False,
)
# Initialize differential attention parameters
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
sublayer_norm = getattr(config, "sublayer_norm", True)
if self.split_heads:
subln_dim = 2 * self.head_dim
else:
subln_dim = self.head_dim
self.subln = (
LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
bsz, q_len, _ = hidden_states.size()
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q_len, device=q1.device)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
# Calculate attention scores for both parts
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
# Apply softmax
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
# Apply dropout
attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Compute differential attention (following paper's formula)
attn_weights = attn1 - lambda_full * attn2
# Apply attention weights to values
attn = torch.matmul(attn_weights, v)
# Apply sublayer norm and scaling
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
if output_attentions:
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
"""Differential Attention implementation as described in the Diff Transformer paper.
This implements the same logic as `LlamaDifferentialAttention`, but uses
`scaled_dot_product_attention` instead of "manually" computing it under the hood.
This implements a modified attention mechanism that computes the difference between
two attention patterns, scaled by learned lambda parameters. The mechanism helps
reduce noise in the attention weights for irrelevant / less relevant tokens.
Key components:
- Split head dimension for differential computation
- Learned lambda parameters that control attention scaling
- Sublayer normalization on the attention output
See:
- https://arxiv.org/abs/2410.05258
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
Args:
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
"""
# pylint: disable=duplicate-code
def forward(
self,
hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
if output_attentions:
transformers.logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q_len, device=q1.device)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
causal_mask = None
if attention_mask is not None:
causal_mask = attention_mask
causal_mask = causal_mask[:, :, :, : k1.shape[-2]]
# SDPA with memory-efficient backend requires contiguous inputs on CUDA
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
# Calculate attention using SDPA
is_causal = attention_mask is None and q_len > 1
dropout_p = self.attention_dropout if self.training else 0.0
attn1 = F.scaled_dot_product_attention(
q1,
k1,
v,
attn_mask=causal_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
attn2 = F.scaled_dot_product_attention(
q2,
k2,
v,
attn_mask=causal_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Combine the attention outputs
attn = attn1 - lambda_full * attn2
# Apply sublayer norm and scaling
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
if output_attentions:
return (
attn,
None,
past_key_value,
) # Note: can't return attn_weights with SDPA
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
"""Differential Attention implementation using Flash Attention 2.
This implements the same logic as `LlamaDifferentialAttention`, but uses
Flash Attention 2 for more efficient computation.
This implements a modified attention mechanism that computes the difference between
two attention patterns, scaled by learned lambda parameters. The mechanism helps
reduce noise in the attention weights for irrelevant / less relevant tokens.
Key components:
- Split head dimension for differential computation
- Learned lambda parameters that control attention scaling
- Sublayer normalization on the attention output
- Flash Attention 2 for efficient attention computation
See:
- https://arxiv.org/abs/2410.05258
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
Args:
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
"""
# pylint: disable=duplicate-code
def forward(
self,
hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
if output_attentions:
transformers.logger.warning_once(
"LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. "
"Falling back to the manual attention implementation."
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q_len, device=q1.device)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
q1 = q1.transpose(1, 2)
q2 = q2.transpose(1, 2)
k1 = k1.transpose(1, 2)
k2 = k2.transpose(1, 2)
v = v.transpose(1, 2)
# Calculate attention using Flash Attention
dropout_p = self.attention_dropout if self.training else 0.0
if self.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(
q1,
k1,
v1,
dropout_p=dropout_p,
causal=True,
)
attn12 = flash_attn_func(
q1,
k1,
v2,
dropout_p=dropout_p,
causal=True,
)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(
q2,
k2,
v1,
dropout_p=dropout_p,
causal=True,
)
attn22 = flash_attn_func(
q2,
k2,
v2,
dropout_p=dropout_p,
causal=True,
)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(
q1,
k1,
v,
dropout_p=dropout_p,
causal=True,
)
attn2 = flash_attn_func(
q2,
k2,
v,
dropout_p=dropout_p,
causal=True,
)
attn1 = attn1.transpose(1, 2)
attn2 = attn2.transpose(1, 2)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Combine the attention outputs
attn = attn1 - lambda_full * attn2
# Apply sublayer norm and scaling
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
if output_attentions:
return (
attn,
None,
past_key_value,
) # Note: can't return attn_weights with Flash Attention
return attn, None, past_key_value


@@ -3,7 +3,7 @@
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.differential_transformer.differential_attention import (
from axolotl.integrations.diff_transformer.diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,


@@ -714,7 +714,7 @@ class ModelLoader:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
if self.cfg.differential_attention:
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
@@ -727,7 +727,7 @@ class ModelLoader:
"flash_attention_2"
)
elif self.cfg.sdp_attention:
if self.cfg.differential_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
@@ -738,7 +738,7 @@ class ModelLoader:
"sdpa"
)
elif self.cfg.eager_attention:
if self.cfg.differential_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
@@ -748,7 +748,7 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.differential_attention:
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"

src/axolotl/utils/yaml.py (new file, 151 lines)

@@ -0,0 +1,151 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representer
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))
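End to end, `convert_diff_transformer` uses this module to re-emit the modified config with the key order and blank-line section breaks of the original file. A minimal usage sketch, assuming a local `config.yml`; the output path is illustrative:

```python
# Sketch of the intended round-trip: load, modify, and re-dump a config while
# borrowing ordering and spacing from the original file.
import yaml

from axolotl.utils.yaml import dump_yaml_preserved_order

with open("config.yml", "r", encoding="utf-8") as file:
    modified_cfg = yaml.safe_load(file) or {}

modified_cfg["diff_attention"] = True  # example modification

dump_yaml_preserved_order(
    data=modified_cfg,
    reference_yaml_path="config.yml",
    output_path="config.converted.yml",  # hypothetical output location
)
```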


@@ -9,9 +9,6 @@ def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
],
"datasets": [
{
"path": "axolotl-ai-co/alpaca_100_test",


@@ -8,9 +8,7 @@ from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
@@ -26,7 +24,7 @@ def test_conversion_and_eval_cli(tmp_path: Path, base_config):
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()


@@ -10,23 +10,19 @@ import pytest
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(
cli, ["convert-differential-transformer", "nonexistent.yml"]
)
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
@@ -37,11 +33,9 @@ def test_basic_execution(cli_runner, tmp_path: Path, base_config):
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-differential-transformer", str(config_path)]
)
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
@@ -56,14 +50,9 @@ def test_conversion_cli_basic(tmp_path: Path, base_config):
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
# Load config the same way do_cli does
cfg = load_cfg(str(config_path))
# Create CLI args
cli_args = ConvertDiffTransformerCliArgs()
# Call convert_differential_transformer directly
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info
assert (output_dir / "model.safetensors").exists()
@@ -79,14 +68,9 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
# Load config the same way do_cli does
cfg = load_cfg(str(config_path))
# Create CLI args
cli_args = ConvertDiffTransformerCliArgs(debug=True)
# Call convert_differential_transformer directly
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"]
assert not debug_info["match_expected"]
@@ -107,7 +91,7 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
@@ -133,7 +117,7 @@ def test_conversion_cli_repoduce_attentions(
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
@@ -155,7 +139,7 @@ def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()