adding yaml dumper preserving input config format
This commit is contained in:
@@ -14,7 +14,8 @@ from transformers import HfArgumentParser
|
|||||||
|
|
||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
|
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
|
||||||
|
from axolotl.utils.yaml import dump_yaml_preserved_order
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def convert_differential_transformer(cfg, cli_args, config_path):
|
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||||
debug_info = {}
|
debug_info = {}
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
@@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
|||||||
LOG.info("Saving updated config to %s", output_config_path)
|
LOG.info("Saving updated config to %s", output_config_path)
|
||||||
|
|
||||||
with open(config_path, "r", encoding="utf-8") as file:
|
with open(config_path, "r", encoding="utf-8") as file:
|
||||||
data = yaml.safe_load(file) or {}
|
modified_cfg = yaml.safe_load(file) or {}
|
||||||
|
|
||||||
data["base_model"] = cfg.output_dir
|
modified_cfg["base_model"] = cfg.output_dir
|
||||||
data["differential_attention"] = True
|
modified_cfg["diff_attention"] = True
|
||||||
data["plugins"] = [
|
plugin_class = (
|
||||||
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
|
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
|
||||||
]
|
)
|
||||||
|
if "plugins" in modified_cfg:
|
||||||
|
modified_cfg["plugins"].append(plugin_class)
|
||||||
|
else:
|
||||||
|
modified_cfg["plugins"] = [plugin_class]
|
||||||
|
|
||||||
with open(output_config_path, "w", encoding="utf-8") as file:
|
dump_yaml_preserved_order(
|
||||||
yaml.dump(data, file)
|
data=modified_cfg,
|
||||||
|
reference_yaml_path=config_path,
|
||||||
|
output_path=output_config_path,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
LOG.info("Not saving converted model to disk")
|
LOG.info("Not saving converted model to disk")
|
||||||
LOG.info("Pass --output-dir path/to/save to save model")
|
LOG.info("Pass --output-dir path/to/save to save model")
|
||||||
@@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
|
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
|
||||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
convert_differential_transformer(cfg, cli_args, config)
|
convert_diff_transformer(cfg, cli_args, config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -252,11 +252,11 @@ def merge_lora(
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def convert_differential_transformer(config: str, **kwargs):
|
def convert_diff_transformer(config: str, **kwargs):
|
||||||
"""Convert model attention layers to differential attention layers."""
|
"""Convert model attention layers to differential attention layers."""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
from axolotl.cli.integrations.convert_differential_transformer import do_cli
|
from axolotl.cli.integrations.convert_diff_transformer import do_cli
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConvertDiffTransformerCliArgs:
|
class ConvertDiffTransformerCliArgs:
|
||||||
"""
|
"""
|
||||||
dataclass with arguments for convert-differential-transformer CLI
|
dataclass with arguments for convert-diff-transformer CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
|
|||||||
10
src/axolotl/integrations/diff_transformer/README.md
Normal file
10
src/axolotl/integrations/diff_transformer/README.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Differential Transformer
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
|
||||||
|
|
||||||
|
diff_attention: true
|
||||||
|
```
|
||||||
@@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
|
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
"""Apply differential attention patch before model loading if enabled."""
|
"""Apply differential attention patch before model loading if enabled."""
|
||||||
if cfg.differential_attention:
|
if cfg.diff_attention:
|
||||||
from axolotl.monkeypatch.attention.differential import (
|
from axolotl.monkeypatch.attention.differential import (
|
||||||
patch_llama_attention_classes,
|
patch_llama_attention_classes,
|
||||||
)
|
)
|
||||||
@@ -11,4 +11,4 @@ LOG = logging.getLogger(__name__)
|
|||||||
class DifferentialTransformerArgs(BaseModel):
|
class DifferentialTransformerArgs(BaseModel):
|
||||||
"""Input args for differential transformer."""
|
"""Input args for differential transformer."""
|
||||||
|
|
||||||
differential_attention: Optional[bool] = None
|
diff_attention: Optional[bool] = None
|
||||||
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
LlamaSdpaAttention,
|
LlamaSdpaAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .differential_attention import (
|
from .diff_attn import (
|
||||||
LlamaDifferentialAttention,
|
LlamaDifferentialAttention,
|
||||||
LlamaDifferentialFlashAttention2,
|
LlamaDifferentialFlashAttention2,
|
||||||
LlamaDifferentialSdpaAttention,
|
LlamaDifferentialSdpaAttention,
|
||||||
375
src/axolotl/integrations/diff_transformer/diff_attn.py
Normal file
375
src/axolotl/integrations/diff_transformer/diff_attn.py
Normal file
@@ -0,0 +1,375 @@
|
|||||||
|
"""Re-implementation of differential attention."""
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaRMSNorm,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of a 4-D tensor ``n_rep`` times along dim 1.

    Produces the same values as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Args:
        x: Tensor whose shape unpacks as
            ``(batch_size, n_kv_heads, slen, head_dim)``.
        n_rep: Number of copies of every head; ``1`` returns ``x`` itself.

    Returns:
        Tensor of shape ``(batch_size, n_kv_heads * n_rep, slen, head_dim)``.
    """
    batch, kv_heads, seq_len, dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis right after the head axis, broadcast it, then
    # fold it back into the head axis so copies of a head stay adjacent.
    expanded = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def lambda_init_fn(depth):
    """Return the depth-dependent initial lambda, ``0.8 - 0.6 * exp(-0.3 * depth)``."""
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialAttentionBase(nn.Module):
    """Base class for differential attention implementations.

    Builds the Q/K/V/O projections, the lambda parameters, the rotary
    embedding and the optional sub-layer norm, and provides the helper steps
    shared by the concrete ``forward`` variants.

    Args:
        config: Model configuration. The attributes read here are
            ``attention_dropout``, ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings``,
            ``rope_theta``, ``split_heads`` and, optionally,
            ``sublayer_norm`` (treated as ``True`` when absent).
        layer_idx: Index of this layer; used for the initial lambda value and
            as the layer key when updating the KV cache.
    """

    def __init__(self, config: Any, layer_idx: int):
        super().__init__()
        self._init_config(config, layer_idx)
        self._init_projections()
        self._init_differential_params()
        self._init_normalization(config)

    def _init_config(self, config: Any, layer_idx: int):
        """Initialize configuration parameters."""
        # Kept so the rotary embedding can be built from the model config.
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.split_heads = config.split_heads

        if config.split_heads:
            # Split heads mode - single projections
            self.head_dim = config.hidden_size // config.num_attention_heads // 2
            # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
            # implementation, which asserts `self.base_num_heads` is even.
            self.heads_per_component = self.base_num_heads // 2
            self.value_head_dim = 2 * self.head_dim
        else:
            # Double projection mode
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.heads_per_component = self.base_num_heads
            self.value_head_dim = self.head_dim

    def _init_projections(self):
        """Initialize Q, K, V and output projections (all without bias)."""
        kv_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
        if self.split_heads:
            # Split heads mode - single projections
            q_out_dim = self.hidden_size
            k_out_dim = kv_dim
        else:
            # Double projection mode - Q and K each carry two components
            q_out_dim = self.hidden_size * 2
            k_out_dim = kv_dim * 2

        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _init_differential_params(self):
        """Initialize differential attention parameters and the rotary embedding."""
        # lambda_init is stored as a frozen parameter (requires_grad=False).
        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        # Build the rotary embedding from the model config instead of passing
        # (max_position_embeddings, head_dim, rope_theta) positionally: that
        # ordering does not match the constructor, and in split-heads mode the
        # embedding must use the un-halved head dim because
        # `_apply_rotary_embeddings` keeps only half of cos/sin.
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def _init_normalization(self, config):
        """Initialize the sub-layer norm (identity when disabled in config)."""
        sublayer_norm = getattr(config, "sublayer_norm", True)
        self.subln = (
            LlamaRMSNorm(self.value_head_dim, eps=1e-5)
            if sublayer_norm
            else nn.Identity()
        )

    def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
        """Project hidden states and reshape them into per-head tensors.

        Args:
            hidden_states: 3-D tensor; its first two dims are used as
                ``(bsz, q_len)``.

        Returns:
            ``(q1, q2, k1, k2, v)``, each laid out as
            ``(bsz, heads, q_len, dim)`` where ``dim`` is ``head_dim`` for
            Q/K and ``value_head_dim`` for V.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project, then split Q and K into their two components
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape to (bsz, heads, q_len, dim)
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)

        return q1, q2, k1, k2, v

    def _apply_rotary_embeddings(
        self, q1, q2, k1, k2, position_ids, position_embeddings
    ):
        """Apply rotary embeddings to both query/key components.

        Uses ``position_embeddings`` (a ``(cos, sin)`` pair) when given;
        otherwise computes them from ``position_ids``, defaulting to
        ``0..q_len-1``.

        Returns:
            ``(q1, q2, k1, k2, cos, sin)`` with the cos/sin actually applied.
        """
        if position_embeddings is None:
            if position_ids is None:
                # Add a leading batch dim: cos/sin are chunked on dim=2 below,
                # i.e. they are expected to be 3-D, which requires 2-D
                # (batch, seq) position ids.
                position_ids = torch.arange(
                    q1.size(-2), device=q1.device
                ).unsqueeze(0)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Q/K heads are half-width in this mode; keep the first half of
            # the rotary table so it matches `head_dim`.
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        return q1, q2, k1, k2, cos, sin

    def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
        """Update the KV cache (if any) and repeat KV heads to match Q heads."""
        if past_key_value is not None:
            # Stack both key components so one cache update covers them.
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        return k1, k2, v

    def _compute_lambda(self, q1):
        """Compute the lambda scalar for differential attention.

        Evaluates ``exp(sum(lambda_q1 * lambda_k1)) -
        exp(sum(lambda_q2 * lambda_k2)) + lambda_init``; the exponentials are
        taken in float32 and cast to the dtype of ``q1``.
        """
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        return lambda_1 - lambda_2 + self.lambda_init

    def _process_attention_output(self, attn, bsz, q_len):
        """Normalize, rescale by ``1 - lambda_init``, merge heads and project."""
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialAttention(DifferentialAttentionBase):
    """Differential attention computed with explicit matmul and softmax."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Run differential attention over ``hidden_states``.

        Returns:
            ``(output, weights, past_key_value)`` where ``weights`` is the
            combined map ``attn1 - lambda * attn2`` when
            ``output_attentions`` is set and ``None`` otherwise.
        """
        batch, seq_len, _ = hidden_states.size()

        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # Scaled dot-product scores for the two attention components
        scale = math.sqrt(self.head_dim)
        scores_1 = torch.matmul(q1, k1.transpose(-1, -2)) / scale
        scores_2 = torch.matmul(q2, k2.transpose(-1, -2)) / scale

        if attention_mask is not None:
            # Trim the additive mask to the current key length.
            mask = attention_mask[:, :, :, : k1.shape[-2]]
            scores_1 = scores_1 + mask
            scores_2 = scores_2 + mask

        # Softmax in float32, then back to the score dtype
        probs_1 = F.softmax(scores_1, dim=-1, dtype=torch.float32).type_as(scores_1)
        probs_2 = F.softmax(scores_2, dim=-1, dtype=torch.float32).type_as(scores_2)

        drop_rate = self.attention_dropout if self.training else 0.0
        probs_1 = F.dropout(probs_1, p=drop_rate, training=self.training)
        probs_2 = F.dropout(probs_2, p=drop_rate, training=self.training)

        lam = self._compute_lambda(q1)
        mixed = torch.matmul(probs_1, v) - lam * torch.matmul(probs_2, v)
        output = self._process_attention_output(mixed, batch, seq_len)

        weights = probs_1 - lam * probs_2 if output_attentions else None
        return output, weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
    """Differential attention built on ``F.scaled_dot_product_attention``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Run differential attention; returns ``(output, None, past_key_value)``.

        When ``output_attentions`` is set the call is delegated to
        ``LlamaDifferentialAttention.forward``, which materializes the
        attention weights.
        """
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        batch, seq_len, _ = hidden_states.size()

        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # With an explicit mask, trim it to the key length; without one,
        # request causal masking whenever more than one query position exists.
        if attention_mask is None:
            mask = None
        else:
            mask = attention_mask[:, :, :, : k1.shape[-2]]
        use_causal = attention_mask is None and seq_len > 1
        drop_rate = self.attention_dropout if self.training else 0.0

        if mask is not None and q1.device.type == "cuda":
            # Make the inputs contiguous on CUDA when a mask is supplied.
            q1 = q1.contiguous()
            q2 = q2.contiguous()
            k1 = k1.contiguous()
            k2 = k2.contiguous()
            v = v.contiguous()

        out_1 = F.scaled_dot_product_attention(
            q1, k1, v, attn_mask=mask, dropout_p=drop_rate, is_causal=use_causal
        )
        out_2 = F.scaled_dot_product_attention(
            q2, k2, v, attn_mask=mask, dropout_p=drop_rate, is_causal=use_causal
        )

        mixed = out_1 - self._compute_lambda(q1) * out_2
        return self._process_attention_output(mixed, batch, seq_len), None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
    """Differential attention built on ``flash_attn_func``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Run differential attention; returns ``(output, None, past_key_value)``.

        When ``output_attentions`` is set the call is delegated to
        ``LlamaDifferentialAttention.forward``. On the flash path
        ``attention_mask`` is not consulted; every kernel call uses
        ``causal=True``.
        """
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        batch, seq_len, _ = hidden_states.size()

        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # Swap the head and sequence axes before calling the flash kernel.
        q1 = q1.transpose(1, 2)
        q2 = q2.transpose(1, 2)
        k1 = k1.transpose(1, 2)
        k2 = k2.transpose(1, 2)
        v = v.transpose(1, 2)

        drop_rate = self.attention_dropout if self.training else 0.0

        if self.split_heads:
            # Values are twice as wide as Q/K here: attend over each value
            # half separately and concatenate the halves again.
            v_lo, v_hi = v.chunk(2, dim=-1)
            out_1_lo = flash_attn_func(q1, k1, v_lo, dropout_p=drop_rate, causal=True)
            out_1_hi = flash_attn_func(q1, k1, v_hi, dropout_p=drop_rate, causal=True)
            out_1 = torch.cat([out_1_lo, out_1_hi], dim=-1)

            out_2_lo = flash_attn_func(q2, k2, v_lo, dropout_p=drop_rate, causal=True)
            out_2_hi = flash_attn_func(q2, k2, v_hi, dropout_p=drop_rate, causal=True)
            out_2 = torch.cat([out_2_lo, out_2_hi], dim=-1)
        else:
            out_1 = flash_attn_func(q1, k1, v, dropout_p=drop_rate, causal=True)
            out_2 = flash_attn_func(q2, k2, v, dropout_p=drop_rate, causal=True)

        # Swap the axes back for the shared output processing.
        out_1 = out_1.transpose(1, 2)
        out_2 = out_2.transpose(1, 2)

        mixed = out_1 - self._compute_lambda(q1) * out_2
        return self._process_attention_output(mixed, batch, seq_len), None, past_key_value
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
# Differential Transformer
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
|
|
||||||
|
|
||||||
differential_attention: true
|
|
||||||
```
|
|
||||||
@@ -1,641 +0,0 @@
|
|||||||
"""Re-implementation of differential attention."""
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import Any, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import transformers
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaRMSNorm,
|
|
||||||
LlamaRotaryEmbedding,
|
|
||||||
apply_rotary_pos_emb,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of a 4-D tensor ``n_rep`` times along dim 1.

    Produces the same values as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Args:
        x: Tensor whose shape unpacks as
            ``(batch_size, n_kv_heads, slen, head_dim)``.
        n_rep: Number of copies of every head; ``1`` returns ``x`` itself.

    Returns:
        Tensor of shape ``(batch_size, n_kv_heads * n_rep, slen, head_dim)``.
    """
    batch, kv_heads, seq_len, dim = x.shape
    if n_rep == 1:
        return x
    # Broadcast a new repeat axis next to the head axis, then merge the two.
    tiled = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, dim)
|
|
||||||
|
|
||||||
|
|
||||||
def lambda_init_fn(depth):
    """Return the depth-dependent initial lambda, ``0.8 - 0.6 * exp(-0.3 * depth)``."""
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialAttention(nn.Module):
    """Differential Attention implementation as described in the Diff Transformer paper.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
    """

    def __init__(
        self,
        config: Any,
        layer_idx: int,
    ):
        super().__init__()

        # Base model config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.split_heads = config.split_heads

        kv_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
        if config.split_heads:
            # Split heads mode: half-width heads, single-sized projections.
            # NOTE: `base_num_heads // 2` rounds down; an odd head count is
            # not rejected here.
            self.head_dim = config.hidden_size // config.num_attention_heads // 2
            self.heads_per_component = self.base_num_heads // 2
            q_out_dim = self.hidden_size
            k_out_dim = kv_dim
        else:
            # Double projection mode: full-width heads, double-sized Q/K.
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.heads_per_component = self.base_num_heads
            q_out_dim = self.hidden_size * 2
            k_out_dim = kv_dim * 2

        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
        # Single V projection
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        # Output projection
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Initialize differential attention parameters
        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        sublayer_norm = getattr(config, "sublayer_norm", True)

        # Values are twice as wide as Q/K heads in split-heads mode.
        subln_dim = 2 * self.head_dim if self.split_heads else self.head_dim
        self.subln = (
            LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
            if sublayer_norm
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Run differential attention over ``hidden_states``.

        Returns:
            ``(output, weights, past_key_value)``; ``weights`` is the
            combined map ``attn1 - lambda * attn2`` when
            ``output_attentions`` is set, else ``None``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q and K to (bsz, heads, q_len, head_dim)
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V (double-width heads in split-heads mode)
        value_dim = 2 * self.head_dim if self.split_heads else self.head_dim
        v = v.view(bsz, q_len, -1, value_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # Add a leading batch dim: cos/sin are chunked on dim=2 below,
                # i.e. they are expected to be 3-D, which requires 2-D
                # (batch, seq) position ids.
                position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Keep the first half of the rotary table to match `head_dim`.
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Stack both key components so one cache update covers them.
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        # Calculate attention scores for both parts
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn1 = attn1 + causal_mask
            attn2 = attn2 + causal_mask

        # Apply softmax (in float32, then cast back)
        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)

        # Apply dropout
        attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
        attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)

        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Compute differential attention (following paper's formula)
        attn_weights = attn1 - lambda_full * attn2

        # Apply attention weights to values
        attn = torch.matmul(attn_weights, v)

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        if output_attentions:
            return attn, attn_weights, past_key_value
        return attn, None, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
    """Differential Attention implementation as described in the Diff Transformer paper.

    This implements the same logic as `LlamaDifferentialAttention`, but uses
    `scaled_dot_product_attention` instead of "manually" computing it under the hood.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    # pylint: disable=duplicate-code
    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Compute differential attention with two SDPA calls.

        Args:
            hidden_states: Input of shape ``[bsz, seq_len, hidden_size]``.
            attention_mask: Optional additive mask; it is sliced on its last
                dimension to the key length before being handed to SDPA.
            position_ids: Only used to build rotary embeddings when
                ``position_embeddings`` is not supplied.
            past_key_value: Optional cache; when given it is updated with the
                stacked ``k1``/``k2`` keys and the values for ``self.layer_idx``.
            output_attentions: When truthy, the call is delegated to the parent
                (manual) implementation, since SDPA cannot return weights.
            use_cache: Only forwarded to the parent implementation on fallback.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed ``(cos, sin)`` pair.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(attn, None, past_key_value)`` where ``attn`` has shape
            ``[bsz, q_len, hidden_size]``. Attention weights are never returned
            on this path.
        """
        if output_attentions:
            transformers.logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention -> [bsz, heads, q_len, head_dim]
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention -> [bsz, kv_heads, q_len, head_dim]
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V; with split heads each value head is twice as wide
        if self.split_heads:
            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
        else:
            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # NOTE(review): this fallback builds 1-D position ids with no
                # batch dimension -- confirm `self.rotary_emb` accepts that shape.
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Keep only the first half of the rotary tables along dim 2
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Both key halves share one cache slot: stack them, update, unstack
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        causal_mask = None
        if attention_mask is not None:
            # Trim the mask to the (possibly cache-extended) key length
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]

        # SDPA with memory-efficient backend requires contiguous inputs on CUDA
        if q1.device.type == "cuda" and causal_mask is not None:
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        # Calculate attention using SDPA; without an explicit mask, rely on
        # SDPA's built-in causal masking (only meaningful for q_len > 1)
        is_causal = attention_mask is None and q_len > 1
        dropout_p = self.attention_dropout if self.training else 0.0
        attn1 = F.scaled_dot_product_attention(
            q1,
            k1,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        attn2 = F.scaled_dot_product_attention(
            q2,
            k2,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # Calculate lambda (exp/sum done in float32, then cast back)
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # `output_attentions` is always falsy here (the truthy case returned
        # above), and SDPA cannot produce attention weights anyway.
        return attn, None, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
    """Differential Attention implementation using Flash Attention 2.

    This implements the same logic as `LlamaDifferentialAttention`, but uses
    Flash Attention 2 for more efficient computation.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output
        - Flash Attention 2 for efficient attention computation

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    # pylint: disable=duplicate-code
    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Compute differential attention with `flash_attn_func`.

        Args:
            hidden_states: Input of shape ``[bsz, seq_len, hidden_size]``.
            attention_mask: Only forwarded to the parent implementation on
                fallback; the flash kernels below are always called with
                ``causal=True`` and never receive this mask.
            position_ids: Only used to build rotary embeddings when
                ``position_embeddings`` is not supplied.
            past_key_value: Optional cache; when given it is updated with the
                stacked ``k1``/``k2`` keys and the values for ``self.layer_idx``.
            output_attentions: When truthy, the call is delegated to the parent
                (manual) implementation.
            use_cache: Only forwarded to the parent implementation on fallback.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed ``(cos, sin)`` pair.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(attn, None, past_key_value)`` where ``attn`` has shape
            ``[bsz, q_len, hidden_size]``. Attention weights are never returned
            on this path.
        """
        if output_attentions:
            transformers.logger.warning_once(
                "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. "
                "Falling back to the manual attention implementation."
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention -> [bsz, heads, q_len, head_dim]
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention -> [bsz, kv_heads, q_len, head_dim]
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V; with split heads each value head is twice as wide
        if self.split_heads:
            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
        else:
            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # NOTE(review): this fallback builds 1-D position ids with no
                # batch dimension -- confirm `self.rotary_emb` accepts that shape.
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Keep only the first half of the rotary tables along dim 2
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Both key halves share one cache slot: stack them, update, unstack
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        # Swap the head and sequence axes before calling the flash kernel
        q1 = q1.transpose(1, 2)
        q2 = q2.transpose(1, 2)
        k1 = k1.transpose(1, 2)
        k2 = k2.transpose(1, 2)
        v = v.transpose(1, 2)

        # Calculate attention using Flash Attention
        dropout_p = self.attention_dropout if self.training else 0.0
        if self.split_heads:
            # The double-width values are attended in two halves and the
            # results concatenated back along the feature dimension.
            v1, v2 = v.chunk(2, dim=-1)
            attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
            attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
            attn1 = torch.cat([attn11, attn12], dim=-1)

            attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
            attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
            attn2 = torch.cat([attn21, attn22], dim=-1)
        else:
            attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
            attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)

        # Swap axes 1 and 2 back for the remaining steps
        attn1 = attn1.transpose(1, 2)
        attn2 = attn2.transpose(1, 2)

        # Calculate lambda (exp/sum done in float32, then cast back)
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # `output_attentions` is always falsy here (the truthy case returned
        # above), and Flash Attention cannot produce attention weights anyway.
        return attn, None, past_key_value
|
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
||||||
|
|
||||||
from axolotl.integrations.differential_transformer.differential_attention import (
|
from axolotl.integrations.diff_transformer.diff_attn import (
|
||||||
LlamaDifferentialAttention,
|
LlamaDifferentialAttention,
|
||||||
LlamaDifferentialFlashAttention2,
|
LlamaDifferentialFlashAttention2,
|
||||||
LlamaDifferentialSdpaAttention,
|
LlamaDifferentialSdpaAttention,
|
||||||
|
|||||||
@@ -714,7 +714,7 @@ class ModelLoader:
|
|||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.cfg.differential_attention:
|
if self.cfg.diff_attention:
|
||||||
self.model_kwargs[
|
self.model_kwargs[
|
||||||
"attn_implementation"
|
"attn_implementation"
|
||||||
] = "differential_flash_attention_2"
|
] = "differential_flash_attention_2"
|
||||||
@@ -727,7 +727,7 @@ class ModelLoader:
|
|||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
elif self.cfg.sdp_attention:
|
elif self.cfg.sdp_attention:
|
||||||
if self.cfg.differential_attention:
|
if self.cfg.diff_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_sdpa"
|
"differential_sdpa"
|
||||||
@@ -738,7 +738,7 @@ class ModelLoader:
|
|||||||
"sdpa"
|
"sdpa"
|
||||||
)
|
)
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
if self.cfg.differential_attention:
|
if self.cfg.diff_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_eager"
|
"differential_eager"
|
||||||
@@ -748,7 +748,7 @@ class ModelLoader:
|
|||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"eager"
|
"eager"
|
||||||
)
|
)
|
||||||
elif self.cfg.differential_attention:
|
elif self.cfg.diff_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_eager"
|
"differential_eager"
|
||||||
|
|||||||
151
src/axolotl/utils/yaml.py
Normal file
151
src/axolotl/utils/yaml.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Utilities for YAML files."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Dict, List, Set, Tuple, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class YAMLOrderTracker:
    """Tracks the order of keys and section breaks in YAML files.

    The file is scanned line by line (it is not parsed as YAML). Two results
    are stored on the instance:

    - ``structure``: nested ``OrderedDict`` mirroring the order in which keys
      appear, with nesting derived from indentation.
    - ``needs_break``: names of top-level keys that were preceded by at least
      one blank or comment line.
    """

    def __init__(self, yaml_path: str):
        self.yaml_path = yaml_path
        self.structure, self.needs_break = self._parse_yaml_structure()

    def _get_indentation_level(self, line: str) -> int:
        """Get the indentation level of a line (count of leading whitespace chars)."""
        return len(line) - len(line.lstrip())

    def _parse_yaml_structure(
        self,
    ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
        """Parse the YAML file to extract structure and identify section breaks.

        Nesting is tracked with a stack of ``(indentation, key)`` pairs rather
        than by assuming a fixed indent width, so files indented with any
        number of spaces are handled.

        Returns:
            A ``(structure, needs_break)`` tuple as described on the class.
        """
        with open(self.yaml_path, "r", encoding="utf-8") as file:
            contents = file.readlines()

        structure: OrderedDict = OrderedDict()
        needs_break: Set[str] = set()  # Keys that should have a break before them
        # Currently open keys, outermost first; indentations strictly increase.
        open_keys: List[Tuple[int, str]] = []
        had_empty_line = False

        for line in contents:
            content = line.strip()

            # Track empty lines and comments
            if not content or content.startswith("#"):
                had_empty_line = True
                continue

            # Skip lines that don't define keys
            if ":" not in content:
                continue

            indentation = self._get_indentation_level(line)
            key = content.split(":")[0].strip()

            # If this is a top-level key and we had an empty line, mark it
            if indentation == 0:
                if had_empty_line:
                    needs_break.add(key)
                had_empty_line = False

            # Close every key that is at the same depth or deeper than this
            # line, then open this key beneath whatever remains.
            while open_keys and open_keys[-1][0] >= indentation:
                open_keys.pop()
            open_keys.append((indentation, key))

            # Make sure every key along the current path exists in structure
            current_dict = structure
            for _, path_key in open_keys:
                current_dict = current_dict.setdefault(path_key, OrderedDict())

        return structure, needs_break
|
||||||
|
|
||||||
|
|
||||||
|
class OrderedDumper(yaml.SafeDumper):
    """`yaml.SafeDumper` subclass used for order-preserving dictionary output.

    It adds no behavior of its own; it exists as a dedicated dumper class on
    which this module's mapping representer is registered.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
    """Represent ``data`` as a plain YAML mapping node.

    The key/value pairs are handed over as an items view rather than as the
    dictionary itself, in the dictionary's own iteration order.
    """
    pairs = data.items()
    return dumper.represent_mapping("tag:yaml.org,2002:map", pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
    """Return a copy of ``data`` whose keys follow ``reference_structure``.

    Keys present in both mappings come first, in reference order; when both
    sides hold a dict for a key, that sub-dict is reordered recursively. Keys
    found only in ``data`` are appended afterwards in their existing order.
    Keys found only in the reference are ignored.
    """
    result: OrderedDict = OrderedDict()

    # Pass 1: keys known to the reference, in reference order
    for name, ref_child in reference_structure.items():
        if name not in data:
            continue
        value = data[name]
        both_mappings = isinstance(ref_child, dict) and isinstance(value, dict)
        result[name] = reorder_dict(value, ref_child) if both_mappings else value

    # Pass 2: everything the reference did not mention
    for name, value in data.items():
        result.setdefault(name, value)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def dump_yaml_preserved_order(
    data: Dict, reference_yaml_path: str, output_path: str
) -> None:
    """Write ``data`` as YAML, mimicking the key order and spacing of a reference file.

    Args:
        data: Mapping to serialize.
        reference_yaml_path: Existing YAML file whose key order and blank-line
            placement before top-level keys should be reproduced.
        output_path: Destination file; it is overwritten.
    """
    # Learn key order and section breaks from the reference file
    tracker = YAMLOrderTracker(reference_yaml_path)
    ordered_data = reorder_dict(data, tracker.structure)

    # Register the custom representer for both mapping types
    for mapping_type in (dict, OrderedDict):
        OrderedDumper.add_representer(mapping_type, ordered_dict_representer)

    # Serialize to a string first so spacing can be adjusted afterwards
    yaml_str = yaml.dump(
        ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
    )

    result_lines: List[str] = []
    for line in yaml_str.split("\n"):
        # A non-blank, unindented line containing ":" is treated as a top-level key
        is_top_level_key = (
            bool(line.strip()) and ":" in line and not line.startswith(" ")
        )
        if is_top_level_key:
            key = line.split(":")[0].strip()
            wants_break = key in tracker.needs_break
            # Insert a single empty line, never at the start or after another one
            if wants_break and result_lines and result_lines[-1] != "":
                result_lines.append("")
        result_lines.append(line)

    # Write the final result
    with open(output_path, "w", encoding="utf-8") as file:
        file.write("\n".join(result_lines))
|
||||||
@@ -9,9 +9,6 @@ def base_config():
|
|||||||
"""Basic config for testing."""
|
"""Basic config for testing."""
|
||||||
return {
|
return {
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
|
|
||||||
],
|
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "axolotl-ai-co/alpaca_100_test",
|
"path": "axolotl-ai-co/alpaca_100_test",
|
||||||
@@ -8,9 +8,7 @@ from pytest import approx
|
|||||||
|
|
||||||
from axolotl.cli import load_cfg
|
from axolotl.cli import load_cfg
|
||||||
from axolotl.cli.evaluate import do_evaluate
|
from axolotl.cli.evaluate import do_evaluate
|
||||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
|
||||||
convert_differential_transformer,
|
|
||||||
)
|
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +24,7 @@ def test_conversion_and_eval_cli(tmp_path: Path, base_config):
|
|||||||
cli_args = ConvertDiffTransformerCliArgs(
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
)
|
)
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
assert debug_info["generations_match"] is True
|
||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
@@ -10,23 +10,19 @@ import pytest
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli import load_cfg
|
from axolotl.cli import load_cfg
|
||||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
|
||||||
convert_differential_transformer,
|
|
||||||
)
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
|
||||||
def test_cli_validation(cli_runner):
|
def test_cli_validation(cli_runner):
|
||||||
# Test missing config file
|
# Test missing config file
|
||||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
|
||||||
assert result.exit_code != 0
|
assert result.exit_code != 0
|
||||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||||
|
|
||||||
# Test non-existent config file
|
# Test non-existent config file
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
|
||||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
|
||||||
)
|
|
||||||
assert result.exit_code != 0
|
assert result.exit_code != 0
|
||||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||||
|
|
||||||
@@ -37,11 +33,9 @@ def test_basic_execution(cli_runner, tmp_path: Path, base_config):
|
|||||||
yaml.dump(base_config, file)
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
|
||||||
) as mock_do_cli:
|
) as mock_do_cli:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
|
||||||
cli, ["convert-differential-transformer", str(config_path)]
|
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
mock_do_cli.assert_called_once()
|
mock_do_cli.assert_called_once()
|
||||||
@@ -56,14 +50,9 @@ def test_conversion_cli_basic(tmp_path: Path, base_config):
|
|||||||
with open(config_path, "w", encoding="utf-8") as file:
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
yaml.dump(base_config, file)
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
# Load config the same way do_cli does
|
|
||||||
cfg = load_cfg(str(config_path))
|
cfg = load_cfg(str(config_path))
|
||||||
|
|
||||||
# Create CLI args
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs()
|
cli_args = ConvertDiffTransformerCliArgs()
|
||||||
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
# Call convert_differential_transformer directly
|
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert not debug_info
|
assert not debug_info
|
||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
@@ -79,14 +68,9 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
|
|||||||
with open(config_path, "w", encoding="utf-8") as file:
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
yaml.dump(base_config, file)
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
# Load config the same way do_cli does
|
|
||||||
cfg = load_cfg(str(config_path))
|
cfg = load_cfg(str(config_path))
|
||||||
|
|
||||||
# Create CLI args
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||||
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
# Call convert_differential_transformer directly
|
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert not debug_info["generations_match"]
|
assert not debug_info["generations_match"]
|
||||||
assert not debug_info["match_expected"]
|
assert not debug_info["match_expected"]
|
||||||
@@ -107,7 +91,7 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
|
|||||||
cli_args = ConvertDiffTransformerCliArgs(
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
)
|
)
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
assert debug_info["generations_match"] is True
|
||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
@@ -133,7 +117,7 @@ def test_conversion_cli_repoduce_attentions(
|
|||||||
cli_args = ConvertDiffTransformerCliArgs(
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
)
|
)
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
assert debug_info["generations_match"] is True
|
||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
@@ -155,7 +139,7 @@ def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str)
|
|||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
cfg = load_cfg(str(config_path))
|
||||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
assert debug_info["generations_match"] is False
|
assert debug_info["generations_match"] is False
|
||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
Reference in New Issue
Block a user