adding yaml dumper preserving input config format
This commit is contained in:
@@ -14,7 +14,8 @@ from transformers import HfArgumentParser
|
||||
|
||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
|
||||
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
|
||||
from axolotl.utils.yaml import dump_yaml_preserved_order
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||
raise
|
||||
|
||||
|
||||
def convert_differential_transformer(cfg, cli_args, config_path):
|
||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||
debug_info = {}
|
||||
|
||||
# Load model and tokenizer
|
||||
@@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
||||
LOG.info("Saving updated config to %s", output_config_path)
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
data = yaml.safe_load(file) or {}
|
||||
modified_cfg = yaml.safe_load(file) or {}
|
||||
|
||||
data["base_model"] = cfg.output_dir
|
||||
data["differential_attention"] = True
|
||||
data["plugins"] = [
|
||||
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
|
||||
]
|
||||
modified_cfg["base_model"] = cfg.output_dir
|
||||
modified_cfg["diff_attention"] = True
|
||||
plugin_class = (
|
||||
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
|
||||
)
|
||||
if "plugins" in modified_cfg:
|
||||
modified_cfg["plugins"].append(plugin_class)
|
||||
else:
|
||||
modified_cfg["plugins"] = [plugin_class]
|
||||
|
||||
with open(output_config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(data, file)
|
||||
dump_yaml_preserved_order(
|
||||
data=modified_cfg,
|
||||
reference_yaml_path=config_path,
|
||||
output_path=output_config_path,
|
||||
)
|
||||
else:
|
||||
LOG.info("Not saving converted model to disk")
|
||||
LOG.info("Pass --output-dir path/to/save to save model")
|
||||
@@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
|
||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
convert_differential_transformer(cfg, cli_args, config)
|
||||
convert_diff_transformer(cfg, cli_args, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -252,11 +252,11 @@ def merge_lora(
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def convert_differential_transformer(config: str, **kwargs):
|
||||
def convert_diff_transformer(config: str, **kwargs):
|
||||
"""Convert model attention layers to differential attention layers."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
from axolotl.cli.integrations.convert_differential_transformer import do_cli
|
||||
from axolotl.cli.integrations.convert_diff_transformer import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
|
||||
@dataclass
|
||||
class ConvertDiffTransformerCliArgs:
|
||||
"""
|
||||
dataclass with arguments for convert-differential-transformer CLI
|
||||
dataclass with arguments for convert-diff-transformer CLI
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
|
||||
10
src/axolotl/integrations/diff_transformer/README.md
Normal file
10
src/axolotl/integrations/diff_transformer/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Differential Transformer
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
|
||||
|
||||
diff_attention: true
|
||||
```
|
||||
@@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin):
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
|
||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply differential attention patch before model loading if enabled."""
|
||||
if cfg.differential_attention:
|
||||
if cfg.diff_attention:
|
||||
from axolotl.monkeypatch.attention.differential import (
|
||||
patch_llama_attention_classes,
|
||||
)
|
||||
@@ -11,4 +11,4 @@ LOG = logging.getLogger(__name__)
|
||||
class DifferentialTransformerArgs(BaseModel):
|
||||
"""Input args for differential transformer."""
|
||||
|
||||
differential_attention: Optional[bool] = None
|
||||
diff_attention: Optional[bool] = None
|
||||
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
|
||||
LlamaSdpaAttention,
|
||||
)
|
||||
|
||||
from .differential_attention import (
|
||||
from .diff_attn import (
|
||||
LlamaDifferentialAttention,
|
||||
LlamaDifferentialFlashAttention2,
|
||||
LlamaDifferentialSdpaAttention,
|
||||
375
src/axolotl/integrations/diff_transformer/diff_attn.py
Normal file
375
src/axolotl/integrations/diff_transformer/diff_attn.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""Re-implemention of differential attention."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head along dim 1 ``n_rep`` times.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Args:
        x: 4-D tensor whose shape is unpacked as
            ``(batch, kv_heads, seq_len, head_dim)``.
        n_rep: Number of copies to make of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch, kv_heads * n_rep, seq_len, head_dim)``.
    """
    bsz, kv_heads, seq_len, dim = x.shape
    if n_rep == 1:
        return x
    # Add a repeat axis right after the head axis, broadcast along it, then
    # fold it back into the head axis so copies of one head stay adjacent.
    tiled = x.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
||||
|
||||
|
||||
def lambda_init_fn(depth):
    """Return the initial lambda value for a layer at the given depth.

    The value is ``0.8 - 0.6 * exp(-0.3 * depth)``: 0.2 at depth 0, rising
    towards 0.8 as depth grows.
    """
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
|
||||
|
||||
|
||||
class DifferentialAttentionBase(nn.Module):
    """Base class for differential attention implementations.

    Owns the Q/K/V/output projections, the lambda parameters, the rotary
    embedding and the sub-layer norm, and provides the helper steps that a
    concrete ``forward`` chains together: project and reshape, apply rotary
    embeddings, update the KV cache, compute lambda, and project the output.

    Args:
        config: Model configuration. The attributes read here are
            ``attention_dropout``, ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings``,
            ``rope_theta``, ``split_heads`` and, optionally,
            ``sublayer_norm`` (defaults to ``True``).
        layer_idx: Index of this layer. Used to derive ``lambda_init`` and as
            the layer slot passed to ``past_key_value.update``.
    """

    def __init__(self, config: Any, layer_idx: int):
        super().__init__()
        self._init_config(config, layer_idx)
        self._init_projections()
        self._init_differential_params()
        self._init_normalization(config)

    def _init_config(self, config: Any, layer_idx: int):
        """Copy the needed config values and derive the head dimensions."""
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.split_heads = config.split_heads

        if config.split_heads:
            # Split heads mode - single projections
            self.head_dim = config.hidden_size // config.num_attention_heads // 2
            # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
            # implementation, which asserts `self.base_num_heads` is even.
            self.heads_per_component = self.base_num_heads // 2
            self.value_head_dim = 2 * self.head_dim
        else:
            # Double projection mode
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.heads_per_component = self.base_num_heads
            self.value_head_dim = self.head_dim

    def _init_projections(self):
        """Create the Q, K, V and output projections (all without bias).

        In split-heads mode Q and K keep the base model's output width and are
        later cut in half; otherwise their output width is doubled so each
        half has the base width. V and the output projection always keep the
        base width.
        """
        kv_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
        if self.split_heads:
            # Split heads mode - single projections
            q_out_dim = self.hidden_size
            k_out_dim = kv_dim
        else:
            # Double projection mode
            q_out_dim = self.hidden_size * 2
            k_out_dim = kv_dim * 2

        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _init_differential_params(self):
        """Create the lambda parameters and the fallback rotary embedding.

        ``lambda_init`` is a frozen scalar derived from the layer index; the
        four ``lambda_*`` vectors of length ``head_dim`` are trainable and
        drawn from a normal distribution (mean 0, std 0.1).
        """
        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        # Pass by keyword: the constructor's first positional parameter is the
        # rotary dimension, not the maximum position count. The dimension is
        # the base model's head width because `_apply_rotary_embeddings`
        # halves cos/sin itself in split-heads mode.
        self.rotary_emb = LlamaRotaryEmbedding(
            dim=self.hidden_size // self.base_num_heads,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _init_normalization(self, config):
        """Create the sub-layer norm, or an identity if it is disabled.

        The norm is enabled unless ``config.sublayer_norm`` is set and falsy.
        """
        if getattr(config, "sublayer_norm", True):
            self.subln = LlamaRMSNorm(self.value_head_dim, eps=1e-5)
        else:
            self.subln = nn.Identity()

    def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
        """Project the hidden states and reshape them into per-head tensors.

        Args:
            hidden_states: Tensor whose size is unpacked as
                ``(batch, seq_len, hidden)``.

        Returns:
            ``(q1, q2, k1, k2, v)``, each shaped
            ``(batch, heads, seq_len, dim)``. ``dim`` is ``head_dim`` for the
            queries and keys and ``value_head_dim`` for the values.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project, then split Q and K into their two components
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape to (batch, heads, seq_len, dim)
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)

        return q1, q2, k1, k2, v

    def _apply_rotary_embeddings(
        self, q1, q2, k1, k2, position_ids, position_embeddings
    ):
        """Apply rotary embeddings to both query/key pairs.

        Uses the ``(cos, sin)`` pair in ``position_embeddings`` when given;
        otherwise computes it with ``self.rotary_emb``, defaulting
        ``position_ids`` to ``0..seq_len-1``.

        Returns:
            ``(q1, q2, k1, k2, cos, sin)`` with the rotated queries and keys
            and the cos/sin tensors that were applied.
        """
        if position_embeddings is None:
            if position_ids is None:
                # The rotary module indexes position ids as (batch, seq), so
                # add a leading batch axis; it broadcasts over the real batch.
                position_ids = torch.arange(
                    q1.size(-2), device=q1.device
                ).unsqueeze(0)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Queries and keys are half as wide in split-heads mode, so keep
            # only the first half of cos/sin along the last axis.
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        return q1, q2, k1, k2, cos, sin

    def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
        """Update the KV cache if one is given, then repeat the KV heads.

        ``k1`` and ``k2`` are stacked on a new dim 1 so both go through a
        single ``past_key_value.update`` call, and are unbound again after.

        Returns:
            ``(k1, k2, v)``, each with its heads repeated
            ``base_num_heads // base_num_kv_heads`` times.
        """
        if past_key_value is not None:
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        return k1, k2, v

    def _compute_lambda(self, q1):
        """Compute the scalar lambda that weights the second attention map.

        Returns ``exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2)
        + lambda_init``. The exponentials are computed in float32 and cast to
        the dtype of ``q1``.
        """
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        return lambda_1 - lambda_2 + self.lambda_init

    def _process_attention_output(self, attn, bsz, q_len):
        """Normalize, rescale, merge heads and apply the output projection.

        Args:
            attn: Per-head attention output; dims 1 and 2 are swapped before
                it is flattened to ``(bsz, q_len, hidden_size)``.
            bsz: Batch size.
            q_len: Query sequence length.

        Returns:
            The result of ``o_proj`` on the merged heads.
        """
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn)
|
||||
|
||||
|
||||
class LlamaDifferentialAttention(DifferentialAttentionBase):
    """Eager differential attention using explicit matmul and softmax."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Compute differential attention for ``hidden_states``.

        Two softmax attention maps are built from the (q1, k1) and (q2, k2)
        pairs; the output is the first map applied to the values minus
        lambda times the second map applied to the values.

        Returns:
            ``(output, weights, past_key_value)``. ``weights`` is
            ``attn1 - lambda * attn2`` when ``output_attentions`` is true,
            otherwise ``None``.
        """
        batch, seq_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # Scaled dot-product scores for both components
        scale = math.sqrt(self.head_dim)
        scores1 = torch.matmul(q1, k1.transpose(-1, -2)) / scale
        scores2 = torch.matmul(q2, k2.transpose(-1, -2)) / scale

        if attention_mask is not None:
            # Trim the additive mask to the key length actually attended over
            mask = attention_mask[:, :, :, : k1.shape[-2]]
            scores1 = scores1 + mask
            scores2 = scores2 + mask

        # Softmax in float32, then cast back to the scores' dtype
        probs1 = F.softmax(scores1, dim=-1, dtype=torch.float32).type_as(scores1)
        probs2 = F.softmax(scores2, dim=-1, dtype=torch.float32).type_as(scores2)

        drop_rate = self.attention_dropout if self.training else 0.0
        probs1 = F.dropout(probs1, p=drop_rate, training=self.training)
        probs2 = F.dropout(probs2, p=drop_rate, training=self.training)

        lam = self._compute_lambda(q1)
        mixed = torch.matmul(probs1, v) - lam * torch.matmul(probs2, v)
        output = self._process_attention_output(mixed, batch, seq_len)

        weights = probs1 - lam * probs2 if output_attentions else None
        return output, weights, past_key_value
|
||||
|
||||
|
||||
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
    """Differential attention built on ``F.scaled_dot_product_attention``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Compute differential attention with two SDPA calls.

        This path never produces attention weights, so a request for them is
        delegated to ``LlamaDifferentialAttention.forward``.

        Returns:
            ``(output, None, past_key_value)`` on the SDPA path.
        """
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        batch, seq_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # With an explicit mask, trim it to the key length; without one, let
        # SDPA apply causal masking itself (only when there are several queries).
        if attention_mask is None:
            mask = None
        else:
            mask = attention_mask[:, :, :, : k1.shape[-2]]
        use_causal = attention_mask is None and seq_len > 1
        drop_rate = self.attention_dropout if self.training else 0.0

        # Masked SDPA on CUDA is fed contiguous tensors
        if mask is not None and q1.device.type == "cuda":
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        out1 = F.scaled_dot_product_attention(
            q1, k1, v, attn_mask=mask, dropout_p=drop_rate, is_causal=use_causal
        )
        out2 = F.scaled_dot_product_attention(
            q2, k2, v, attn_mask=mask, dropout_p=drop_rate, is_causal=use_causal
        )

        lam = self._compute_lambda(q1)
        output = self._process_attention_output(out1 - lam * out2, batch, seq_len)
        return output, None, past_key_value
|
||||
|
||||
|
||||
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
    """Differential attention built on ``flash_attn_func``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Compute differential attention with flash attention kernels.

        This path never produces attention weights, so a request for them is
        delegated to ``LlamaDifferentialAttention.forward``. On the flash
        path ``attention_mask`` is not consulted; every call uses
        ``causal=True``.

        Returns:
            ``(output, None, past_key_value)`` on the flash path.
        """
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )

        batch, seq_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        k1, k2, v = self._handle_cache(
            k1,
            k2,
            v,
            past_key_value,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

        # Swap the head and sequence axes before calling the flash kernels
        q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
        k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
        v = v.transpose(1, 2)

        drop_rate = self.attention_dropout if self.training else 0.0

        def attend(query, key, value):
            return flash_attn_func(
                query, key, value, dropout_p=drop_rate, causal=True
            )

        if self.split_heads:
            # Values are twice as wide as queries/keys here, so attend over
            # each value half separately and concatenate the results.
            v1, v2 = v.chunk(2, dim=-1)
            out1 = torch.cat([attend(q1, k1, v1), attend(q1, k1, v2)], dim=-1)
            out2 = torch.cat([attend(q2, k2, v1), attend(q2, k2, v2)], dim=-1)
        else:
            out1 = attend(q1, k1, v)
            out2 = attend(q2, k2, v)

        # Swap the axes back to the layout `_process_attention_output` expects
        out1, out2 = out1.transpose(1, 2), out2.transpose(1, 2)

        lam = self._compute_lambda(q1)
        output = self._process_attention_output(out1 - lam * out2, batch, seq_len)
        return output, None, past_key_value
|
||||
@@ -1,10 +0,0 @@
|
||||
# Differential Transformer
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
|
||||
|
||||
differential_attention: true
|
||||
```
|
||||
@@ -1,641 +0,0 @@
|
||||
"""Re-implemention of differential attention."""
|
||||
# pylint: disable=invalid-name
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each head of ``x`` ``n_rep`` times along dim 1.

    Matches ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``. The input
    shape is unpacked as ``(batch, kv_heads, seq_len, head_dim)``; the input
    is returned unchanged when ``n_rep`` is 1.
    """
    n_batch, n_heads, n_tokens, width = x.shape
    if n_rep == 1:
        return x
    # Broadcast over a new axis next to the head axis, then merge the two
    widened = x.unsqueeze(2).expand(n_batch, n_heads, n_rep, n_tokens, width)
    return widened.reshape(n_batch, n_heads * n_rep, n_tokens, width)
|
||||
|
||||
|
||||
def lambda_init_fn(depth):
    """Compute the depth-dependent initial lambda.

    Evaluates ``0.8 - 0.6 * exp(-0.3 * depth)``.
    """
    falloff = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * falloff
|
||||
|
||||
|
||||
class LlamaDifferentialAttention(nn.Module):
    """Differential Attention implementation as described in the Diff Transformer paper.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
            Also read: ``split_heads`` and, optionally, ``sublayer_norm``
            (defaults to ``True``).
        layer_idx: Index of this layer in the transformer stack
    """

    def __init__(
        self,
        config: Any,
        layer_idx: int,
    ):
        super().__init__()

        # Base model config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.split_heads = config.split_heads

        kv_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
        if self.split_heads:
            # Split heads mode: single-width Q/K projections, halved head dim.
            # NOTE: an odd number of heads is rounded down here, not rejected.
            self.head_dim = config.hidden_size // config.num_attention_heads // 2
            self.heads_per_component = self.base_num_heads // 2
            q_out_dim = self.hidden_size
            k_out_dim = kv_dim
        else:
            # Double projection mode: double-width Q/K projections
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.heads_per_component = self.base_num_heads
            q_out_dim = self.hidden_size * 2
            k_out_dim = kv_dim * 2

        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)

        # Single V projection
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)

        # Output projection
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Initialize differential attention parameters
        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        sublayer_norm = getattr(config, "sublayer_norm", True)

        # The norm runs over the value width, which is doubled in split-heads mode
        subln_dim = 2 * self.head_dim if self.split_heads else self.head_dim
        self.subln = (
            LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
            if sublayer_norm
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Compute differential attention for ``hidden_states``.

        Returns:
            ``(output, weights, past_key_value)``. ``weights`` is the
            differential attention map when ``output_attentions`` is true,
            otherwise ``None``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V
        if self.split_heads:
            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
        else:
            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # The rotary module indexes position ids as (batch, seq), so
                # add a leading batch axis; it broadcasts over the real batch.
                position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Queries/keys are half as wide here; keep the first half of cos/sin
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Stack both key components so one cache update covers them
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        n_rep = self.base_num_heads // self.base_num_kv_heads
        k1 = repeat_kv(k1, n_rep)
        k2 = repeat_kv(k2, n_rep)
        v = repeat_kv(v, n_rep)

        # Calculate attention scores for both parts
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn1 = attn1 + causal_mask
            attn2 = attn2 + causal_mask

        # Apply softmax (in float32, then cast back)
        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)

        # Apply dropout
        attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
        attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)

        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Compute differential attention (following paper's formula)
        attn_weights = attn1 - lambda_full * attn2

        # Apply attention weights to values
        attn = torch.matmul(attn_weights, v)

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        if output_attentions:
            return attn, attn_weights, past_key_value
        return attn, None, past_key_value
|
||||
|
||||
|
||||
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
    """Differential Attention implementation as described in the Diff Transformer paper.

    This implements the same logic as `LlamaDifferentialAttention`, but uses
    `scaled_dot_product_attention` instead of "manually" computing it under the hood.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    # pylint: disable=duplicate-code
    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Compute differential attention with two SDPA calls.

        Args:
            hidden_states: Input of shape [bsz, seq_len, hidden_size].
            attention_mask: Optional mask; when given it is sliced to the key length
                and passed to SDPA as `attn_mask`.
            position_ids: Optional position ids, only used when `position_embeddings`
                is not supplied.
            past_key_value: Optional cache, updated in place with the new K/V.
            output_attentions: When true, the call is delegated to the parent
                (manual) implementation, because SDPA cannot return weights.
            use_cache: Only forwarded to the parent implementation.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed `(cos, sin)` pair.

        Returns:
            `(attn_output, None, past_key_value)`; attention weights are never
            returned on the SDPA path.
        """
        if output_attentions:
            transformers.logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention: [bsz, heads, q_len, head_dim]
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention: [bsz, kv_heads, q_len, head_dim]
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V; with split heads each value head is twice as wide as a Q/K head
        if self.split_heads:
            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
        else:
            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # NOTE(review): this builds a 1-D tensor; confirm `self.rotary_emb`
                # accepts position ids without a batch dimension.
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Keep only the first half of the last axis of cos / sin
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # K1 and K2 travel through the cache as one tensor stacked on dim 1
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        causal_mask = None
        if attention_mask is not None:
            causal_mask = attention_mask
            # Trim the mask's last axis to the (possibly cached) key length
            causal_mask = causal_mask[:, :, :, : k1.shape[-2]]

        # SDPA with memory-efficient backend requires contiguous inputs on CUDA
        if q1.device.type == "cuda" and causal_mask is not None:
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()

        # Calculate attention using SDPA; rely on the built-in causal masking only
        # when no explicit mask was given and there is more than one query token
        is_causal = attention_mask is None and q_len > 1

        dropout_p = self.attention_dropout if self.training else 0.0
        attn1 = F.scaled_dot_product_attention(
            q1,
            k1,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        attn2 = F.scaled_dot_product_attention(
            q2,
            k2,
            v,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # Calculate lambda (exponentials are taken in float32, then cast back)
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # `output_attentions=True` was already delegated to the parent class above,
        # so this point is only reached without attention weights to return
        # (SDPA cannot produce them anyway).
        return attn, None, past_key_value
|
||||
|
||||
|
||||
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
    """Differential Attention implementation using Flash Attention 2.

    This implements the same logic as `LlamaDifferentialAttention`, but uses
    Flash Attention 2 for more efficient computation.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
        - Split head dimension for differential computation
        - Learned lambda parameters that control attention scaling
        - Sublayer normalization on the attention output
        - Flash Attention 2 for efficient attention computation

    See:
        - https://arxiv.org/abs/2410.05258
        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    # pylint: disable=duplicate-code
    def forward(
        self,
        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Compute differential attention with `flash_attn_func`.

        Args:
            hidden_states: Input of shape [bsz, seq_len, hidden_size].
            attention_mask: Only forwarded to the parent implementation when
                `output_attentions` is true; the flash path below does not read it.
            position_ids: Optional position ids, only used when `position_embeddings`
                is not supplied.
            past_key_value: Optional cache, updated in place with the new K/V.
            output_attentions: When true, the call is delegated to the parent
                (manual) implementation.
            use_cache: Only forwarded to the parent implementation.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed `(cos, sin)` pair.

        Returns:
            `(attn_output, None, past_key_value)`; attention weights are never
            returned on the flash-attention path.
        """
        if output_attentions:
            transformers.logger.warning_once(
                "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. "
                "Falling back to the manual attention implementation."
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention: [bsz, heads, q_len, head_dim]
        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention: [bsz, kv_heads, q_len, head_dim]
        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V; with split heads each value head is twice as wide as a Q/K head
        if self.split_heads:
            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
        else:
            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                # NOTE(review): this builds a 1-D tensor; confirm `self.rotary_emb`
                # accepts position ids without a batch dimension.
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        if self.split_heads:
            # Keep only the first half of the last axis of cos / sin
            cos, _ = cos.chunk(2, dim=2)
            sin, _ = sin.chunk(2, dim=2)

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # K1 and K2 travel through the cache as one tensor stacked on dim 1
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        # Swap the head and sequence axes back before calling flash_attn_func
        q1 = q1.transpose(1, 2)
        q2 = q2.transpose(1, 2)
        k1 = k1.transpose(1, 2)
        k2 = k2.transpose(1, 2)
        v = v.transpose(1, 2)

        # Calculate attention using Flash Attention
        # NOTE(review): `attention_mask` is not passed on this path and `causal=True`
        # is unconditional; confirm padded batches are handled by the caller.
        dropout_p = self.attention_dropout if self.training else 0.0
        if self.split_heads:
            # Values are twice as wide as Q/K here, so attend to each half of V
            # separately and concatenate the results along the feature axis
            v1, v2 = v.chunk(2, dim=-1)
            attn11 = flash_attn_func(
                q1,
                k1,
                v1,
                dropout_p=dropout_p,
                causal=True,
            )
            attn12 = flash_attn_func(
                q1,
                k1,
                v2,
                dropout_p=dropout_p,
                causal=True,
            )
            attn1 = torch.cat([attn11, attn12], dim=-1)

            attn21 = flash_attn_func(
                q2,
                k2,
                v1,
                dropout_p=dropout_p,
                causal=True,
            )
            attn22 = flash_attn_func(
                q2,
                k2,
                v2,
                dropout_p=dropout_p,
                causal=True,
            )
            attn2 = torch.cat([attn21, attn22], dim=-1)
        else:
            attn1 = flash_attn_func(
                q1,
                k1,
                v,
                dropout_p=dropout_p,
                causal=True,
            )
            attn2 = flash_attn_func(
                q2,
                k2,
                v,
                dropout_p=dropout_p,
                causal=True,
            )

        # Return to the head-major layout used by the rest of the computation
        attn1 = attn1.transpose(1, 2)
        attn2 = attn2.transpose(1, 2)

        # Calculate lambda (exponentials are taken in float32, then cast back)
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        # `output_attentions=True` was already delegated to the parent class above,
        # so this point is only reached without attention weights to return
        # (Flash Attention cannot produce them anyway).
        return attn, None, past_key_value
|
||||
@@ -3,7 +3,7 @@
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
||||
|
||||
from axolotl.integrations.differential_transformer.differential_attention import (
|
||||
from axolotl.integrations.diff_transformer.diff_attn import (
|
||||
LlamaDifferentialAttention,
|
||||
LlamaDifferentialFlashAttention2,
|
||||
LlamaDifferentialSdpaAttention,
|
||||
|
||||
@@ -714,7 +714,7 @@ class ModelLoader:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
pass
|
||||
|
||||
if self.cfg.differential_attention:
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs[
|
||||
"attn_implementation"
|
||||
] = "differential_flash_attention_2"
|
||||
@@ -727,7 +727,7 @@ class ModelLoader:
|
||||
"flash_attention_2"
|
||||
)
|
||||
elif self.cfg.sdp_attention:
|
||||
if self.cfg.differential_attention:
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_sdpa"
|
||||
@@ -738,7 +738,7 @@ class ModelLoader:
|
||||
"sdpa"
|
||||
)
|
||||
elif self.cfg.eager_attention:
|
||||
if self.cfg.differential_attention:
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_eager"
|
||||
@@ -748,7 +748,7 @@ class ModelLoader:
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif self.cfg.differential_attention:
|
||||
elif self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_eager"
|
||||
|
||||
151
src/axolotl/utils/yaml.py
Normal file
151
src/axolotl/utils/yaml.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Utilities for YAML files."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class YAMLOrderTracker:
    """Tracks the order of keys and section breaks in YAML files.

    The reference file is scanned line by line (it is not run through a YAML
    loader). After construction:

    - ``structure`` is a nested ``OrderedDict`` whose keys follow the order in
      which keys appear in the file; leaves are empty ``OrderedDict`` objects.
    - ``needs_break`` is the set of top-level keys that were preceded by at least
      one blank line or comment line.
    """

    def __init__(self, yaml_path: str):
        self.yaml_path = yaml_path
        self.structure, self.needs_break = self._parse_yaml_structure()

    def _get_indentation_level(self, line: str) -> int:
        """Get the indentation level of a line (count of leading whitespace chars)."""
        return len(line) - len(line.lstrip())

    def _parse_yaml_structure(
        self,
    ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
        """Parse the YAML file to extract structure and identify section breaks.

        Nesting is derived from a stack of ``(indentation, key)`` pairs, so any
        indentation width (2 spaces, 4 spaces, uneven dedents) is handled; no
        fixed step size is assumed.

        Returns:
            A ``(structure, needs_break)`` tuple as described on the class.
        """
        with open(self.yaml_path, "r", encoding="utf-8") as file:
            contents = file.readlines()

        structure: OrderedDict = OrderedDict()
        needs_break: Set[str] = set()  # Keys that should have a break before them
        # Ancestors of the current line, outermost first, with their indentation
        open_keys: List[Tuple[int, str]] = []
        had_empty_line = False

        for line in contents:
            content = line.strip()

            # Track empty lines and comments
            if not content or content.startswith("#"):
                had_empty_line = True
                continue

            # Skip lines that don't define keys
            if ":" not in content:
                continue

            indentation = self._get_indentation_level(line)
            key = content.split(":")[0].strip()

            # If this is a top-level key and we had an empty line, mark it
            if indentation == 0:
                if had_empty_line:
                    needs_break.add(key)
                had_empty_line = False

            # Close every key that is not an ancestor of this line, i.e. anything
            # at the same or a deeper indentation, then open the new key
            while open_keys and open_keys[-1][0] >= indentation:
                open_keys.pop()
            open_keys.append((indentation, key))

            # Update structure, creating missing levels along the path
            current_dict = structure
            for _, path_key in open_keys:
                current_dict = current_dict.setdefault(path_key, OrderedDict())

        return structure, needs_break
|
||||
|
||||
|
||||
class OrderedDumper(yaml.SafeDumper):
    """Custom YAML dumper that maintains dictionary order.

    The class adds no behavior of its own; it exists so that custom representers
    can be registered on it instead of on `yaml.SafeDumper` itself.
    """
|
||||
|
||||
|
||||
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
    """Represent a dictionary as a YAML map built from its items in iteration order.

    Args:
        dumper: Dumper instance doing the serialization.
        data: Mapping to represent.

    Returns:
        The node produced by `dumper.represent_mapping`.
    """
    entries = data.items()
    return dumper.represent_mapping("tag:yaml.org,2002:map", entries)
|
||||
|
||||
|
||||
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
    """Return a copy of `data` whose keys follow `reference_structure`.

    Keys present in both come first, in reference order; nested dictionaries are
    reordered recursively when the reference entry is a dictionary too. Keys that
    only exist in `data` are appended afterwards in their original order.

    Args:
        data: Dictionary to reorder.
        reference_structure: Nested dictionary describing the desired key order.

    Returns:
        A new `OrderedDict`; `data` itself is not modified.
    """

    def _arrange(value, reference):
        # Recurse only when both sides are mappings; anything else is kept as is
        both_mappings = isinstance(reference, dict) and isinstance(value, dict)
        return reorder_dict(value, reference) if both_mappings else value

    result = OrderedDict(
        (name, _arrange(data[name], reference_structure[name]))
        for name in reference_structure
        if name in data
    )

    # Keys unknown to the reference keep their relative order at the end
    for name, value in data.items():
        result.setdefault(name, value)

    return result
|
||||
|
||||
|
||||
def dump_yaml_preserved_order(
    data: Dict, reference_yaml_path: str, output_path: str
) -> None:
    """Write `data` as YAML, following the key order and spacing of a reference file.

    Args:
        data: Configuration dictionary to serialize.
        reference_yaml_path: Existing YAML file whose key order and top-level
            blank-line breaks should be mirrored.
        output_path: Destination file; it is overwritten.
    """
    # Key order and section breaks come from the reference file
    order_tracker = YAMLOrderTracker(reference_yaml_path)
    arranged = reorder_dict(data, order_tracker.structure)

    # Route plain and ordered dicts through the custom representer
    for mapping_type in (dict, OrderedDict):
        OrderedDumper.add_representer(mapping_type, ordered_dict_representer)

    serialized = yaml.dump(
        arranged, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
    )

    # Re-insert a single blank line before each top-level key that had one
    spaced_lines: List[str] = []
    for raw_line in serialized.split("\n"):
        is_top_level_key = (
            bool(raw_line.strip()) and ":" in raw_line and not raw_line.startswith(" ")
        )
        if is_top_level_key:
            name = raw_line.split(":")[0].strip()
            wants_gap = name in order_tracker.needs_break
            if wants_gap and spaced_lines and spaced_lines[-1] != "":
                spaced_lines.append("")
        spaced_lines.append(raw_line)

    with open(output_path, "w", encoding="utf-8") as file:
        file.write("\n".join(spaced_lines))
|
||||
@@ -9,9 +9,6 @@ def base_config():
|
||||
"""Basic config for testing."""
|
||||
return {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"plugins": [
|
||||
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
|
||||
],
|
||||
"datasets": [
|
||||
{
|
||||
"path": "axolotl-ai-co/alpaca_100_test",
|
||||
@@ -8,9 +8,7 @@ from pytest import approx
|
||||
|
||||
from axolotl.cli import load_cfg
|
||||
from axolotl.cli.evaluate import do_evaluate
|
||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||
convert_differential_transformer,
|
||||
)
|
||||
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
||||
|
||||
|
||||
@@ -26,7 +24,7 @@ def test_conversion_and_eval_cli(tmp_path: Path, base_config):
|
||||
cli_args = ConvertDiffTransformerCliArgs(
|
||||
debug=True, zero_init=True, sublayer_norm=False
|
||||
)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is True
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
@@ -10,23 +10,19 @@ import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.cli import load_cfg
|
||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||
convert_differential_transformer,
|
||||
)
|
||||
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
|
||||
from axolotl.cli.main import cli
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||
|
||||
|
||||
def test_cli_validation(cli_runner):
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
||||
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
||||
)
|
||||
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
@@ -37,11 +33,9 @@ def test_basic_execution(cli_runner, tmp_path: Path, base_config):
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
with patch(
|
||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
||||
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
|
||||
) as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", str(config_path)]
|
||||
)
|
||||
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
@@ -56,14 +50,9 @@ def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
# Load config the same way do_cli does
|
||||
cfg = load_cfg(str(config_path))
|
||||
|
||||
# Create CLI args
|
||||
cli_args = ConvertDiffTransformerCliArgs()
|
||||
|
||||
# Call convert_differential_transformer directly
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert not debug_info
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
@@ -79,14 +68,9 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
# Load config the same way do_cli does
|
||||
cfg = load_cfg(str(config_path))
|
||||
|
||||
# Create CLI args
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||
|
||||
# Call convert_differential_transformer directly
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert not debug_info["generations_match"]
|
||||
assert not debug_info["match_expected"]
|
||||
@@ -107,7 +91,7 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
|
||||
cli_args = ConvertDiffTransformerCliArgs(
|
||||
debug=True, zero_init=True, sublayer_norm=False
|
||||
)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is True
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
@@ -133,7 +117,7 @@ def test_conversion_cli_repoduce_attentions(
|
||||
cli_args = ConvertDiffTransformerCliArgs(
|
||||
debug=True, zero_init=True, sublayer_norm=False
|
||||
)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is True
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
@@ -155,7 +139,7 @@ def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str)
|
||||
|
||||
cfg = load_cfg(str(config_path))
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is False
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
Reference in New Issue
Block a user