differential flash attention 2; cleanup
model-out/eval_summary.csv (new file)
@@ -0,0 +1,6 @@
metric,training,validation
loss,1.8773103952407837,1.915901780128479
model_preparation_time,0.0051,0.0051
runtime,89.7635,8.9565
samples_per_second,20.053,22.33
steps_per_second,20.053,22.33
@@ -14,7 +14,9 @@ from transformers import HfArgumentParser

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
from axolotl.integrations.differential_transformer.convert import (
    convert_to_diff_attention,
)

LOG = logging.getLogger("axolotl.cli.convert_attention")

@@ -74,7 +76,11 @@ def convert_diff_transformer(cfg, cli_args, config_path):
        # Convert attention
        LOG.info("Converting to differential attention...")
        try:
            model = convert_to_diff_attention(model, cli_args.zero_init)
            model = convert_to_diff_attention(
                model=model,
                zero_init=cli_args.zero_init,
                sublayer_norm=cli_args.sublayer_norm,
            )
            model.to(cfg.device, dtype=cfg.torch_dtype)
        except Exception as exc:
            LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
@@ -130,43 +136,35 @@ def convert_diff_transformer(cfg, cli_args, config_path):
                    + Fore.RESET
                )
            else:
                if cli_args.zero_init:
                    LOG.info(
                        Fore.RED
                        + "Generations do not match.\n"
                        + "Original generation:\n"
                        + "*" * 50
                        + "\n"
                        + f"{orig_text}\n"
                        + "*" * 50
                        + "\n"
                        + "Converted generation:\n"
                        + "*" * 50
                        + "\n"
                        + f"{conv_text}\n"
                        + "*" * 50
                        + "\n"
                        + Fore.RESET
                    )
                message = (
                    "Generations do not match.\n"
                    + "Original generation:\n"
                    + "*" * 50
                    + "\n"
                    + f"{orig_text}\n"
                    + "*" * 50
                    + "\n"
                    + "Converted generation:\n"
                    + "*" * 50
                    + "\n"
                    + f"{conv_text}\n"
                    + "*" * 50
                    + "\n"
                )

                if cli_args.zero_init and not cli_args.sublayer_norm:
                    LOG.info(Fore.RED + message + Fore.RESET)
                else:
                    LOG.info(
                        Fore.YELLOW
                        + "Generations do not match.\n"
                        + "Original generation:\n"
                        + "*" * 50
                        + "\n"
                        + f"{orig_text}\n"
                        + "*" * 50
                        + "\n"
                        + "Converted generation:\n"
                        + "*" * 50
                        + "\n"
                        + f"{conv_text}\n"
                        + "*" * 50
                        + "\n"
                        + "However, this is expected since --zero-init was not passed."
                        + message
                        + "However, this is expected since --zero-init"
                        + " and --no-sublayer-norm were not passed."
                        + Fore.RESET
                    )

        return model

    except Exception as exc:
        LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
        raise

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
        # Process dataclass fields in reverse order for correct option ordering
        for field in reversed(dataclasses.fields(config_class)):
            field_type = field.type

            if get_origin(field_type) is Union and type(None) in get_args(field_type):
                field_type = next(
                    t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)

        return function

    return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
    def decorator(function):
        # Process model fields in reverse order for correct option ordering
        for name, field in reversed(config_class.model_fields.items()):
            if field.annotation == bool:
            field_type = field.annotation
            if get_origin(field_type) is Union and type(None) in get_args(field_type):
                field_type = next(
                    t for t in get_args(field_type) if not isinstance(t, NoneType)
                )

            # NOTE: defaults are handled by the pydantic model config classes.
            if field_type == bool:
                field_name = name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)

        return function

    return decorator

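As context for the bool branch above: every boolean field on the config becomes a paired on/off flag, which is where --sublayer-norm/--no-sublayer-norm comes from. A minimal standalone sketch of the naming scheme (hypothetical demo command, not part of this diff):

import click

field_name = "sublayer_norm".replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"  # "--sublayer-norm/--no-sublayer-norm"

@click.command()
@click.option(option_name, default=True, help="Toggle sublayer RMSNorm.")
def demo(sublayer_norm: bool):
    # click derives the parameter name `sublayer_norm` from the flag pair
    click.echo(f"sublayer_norm={sublayer_norm}")

if __name__ == "__main__":
    demo()
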
@@ -62,6 +62,7 @@ class ConvertDiffTransformerCliArgs:

    debug: bool = field(default=False)
    zero_init: bool = field(default=False)
    sublayer_norm: bool = field(default=True)


def load_model_and_tokenizer(

@@ -5,22 +5,35 @@ from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

from .multihead_diffattn import (
from .differential_attention import (
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ATTENTION_MAPPING = {
    LlamaAttention: LlamaDifferentialAttention,
    LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
    LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}


def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
    new_attn: Union[
        LlamaDifferentialAttention,
        LlamaDifferentialSdpaAttention,
        LlamaDifferentialFlashAttention2,
    ],
    zero_init: bool = False,
) -> None:
    """
@@ -69,31 +82,24 @@ def copy_attention_weights(


def convert_to_diff_attention(
    model: PreTrainedModel, zero_init: bool
    model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    attention_patterns = (
        LlamaAttention,
        LlamaSdpaAttention,
        MistralAttention,
        MixtralAttention,
    )
    layer_idx = 0

    # Set sublayer norm as config on the model.
    model.config.sublayer_norm = sublayer_norm

    def convert_module(module):
        nonlocal layer_idx

        # Iterate through module children, convert any attn layers to diff attn
        for name, child in module.named_children():
            if isinstance(child, attention_patterns):
                layer_type = type(child).__name__

            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
                # Choose appropriate differential attention class
                if isinstance(child, LlamaSdpaAttention):
                    attention_class = LlamaDifferentialSdpaAttention
                else:
                    attention_class = LlamaDifferentialAttention
                attention_class = ATTENTION_MAPPING[type(child)]

                layer_type = type(child).__name__
                logger.info(
                    f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
                )
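One subtlety in the conversion loop above: the isinstance guard accepts subclasses of the mapped attention types, while ATTENTION_MAPPING[type(child)] is an exact-type lookup. A small self-contained illustration (hypothetical classes, not axolotl code):

class EagerAttn: ...
class SdpaAttn(EagerAttn): ...

MAPPING = {EagerAttn: "differential_eager", SdpaAttn: "differential_sdpa"}

child = SdpaAttn()
assert isinstance(child, tuple(MAPPING.keys()))  # True; subclasses pass the guard
print(MAPPING[type(child)])  # "differential_sdpa"; an unmapped subclass would raise KeyError
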
@@ -7,9 +7,11 @@ from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)
@@ -75,14 +77,11 @@ class LlamaDifferentialAttention(nn.Module):
        self.rope_theta = config.rope_theta
        self.is_causal = True

        dtype = torch.float32

        # For Q1 and Q2
        self.q_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size * 2,
            bias=False,
            dtype=dtype,
        )

        # For K1 and K2
@@ -90,7 +89,6 @@ class LlamaDifferentialAttention(nn.Module):
            self.hidden_size,
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
            bias=False,
            dtype=dtype,
        )

        # Single V projection
@@ -98,7 +96,6 @@ class LlamaDifferentialAttention(nn.Module):
            self.hidden_size,
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
            bias=False,
            dtype=dtype,
        )

        # Output projection
@@ -106,28 +103,33 @@ class LlamaDifferentialAttention(nn.Module):
            self.hidden_size,
            self.hidden_size,
            bias=False,
            dtype=dtype,
        )

        # Initialize differential attention parameters
        self.lambda_init = nn.Parameter(
            torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
            torch.full((), lambda_init_fn(self.layer_idx)),
            requires_grad=False,
        )
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        sublayer_norm = getattr(config, "sublayer_norm", True)
        self.subln = (
            LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
            if sublayer_norm
            else nn.Identity()
        )

    def forward(
        self,
@@ -192,39 +194,21 @@ class LlamaDifferentialAttention(nn.Module):
        # Calculate attention scores for both parts
        # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
        # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )
        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )

        # Add this debug step right after computing attention weights in the forward pass
        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )
        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn_weights1 = attn_weights1 + causal_mask
            attn_weights2 = attn_weights2 + causal_mask
            attn1 = attn1 + causal_mask
            attn2 = attn2 + causal_mask

        # Apply softmax separately as per paper
        attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
            attn_weights1
        )
        attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
            attn_weights2
        )
        attn_weights1 = F.dropout(
            attn_weights1, p=self.attention_dropout, training=self.training
        )
        attn_weights2 = F.dropout(
            attn_weights2, p=self.attention_dropout, training=self.training
        )
        # Apply softmax
        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)

        # Apply dropout
        attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
        attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)

        # Calculate lambda
        lambda_1 = torch.exp(
@@ -236,15 +220,13 @@ class LlamaDifferentialAttention(nn.Module):
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Compute differential attention (following paper's formula)
        attn_weights = attn_weights1 - lambda_full * attn_weights2
        attn_weights = attn1 - lambda_full * attn2

        # Apply attention weights to values
        attn = torch.matmul(attn_weights, v)

        # Apply sublayer norm and scaling
        # NOTE(Dan): The differential transformers paper applies sublayer normalization at this
        # point, but this is typically done outside of the attention layer. It would look something
        # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
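For reference, the combination computed above follows the Differential Transformer formulation (arXiv:2410.05258, cited later in this diff); a sketch of the per-head math implemented by attn1, attn2, and lambda_full:

$$\mathrm{DiffAttn}(X) = \left(\mathrm{softmax}\!\left(\frac{Q_1 K_1^\top}{\sqrt{d}}\right) - \lambda\,\mathrm{softmax}\!\left(\frac{Q_2 K_2^\top}{\sqrt{d}}\right)\right) V$$

$$\lambda = e^{\lambda_{q_1} \cdot \lambda_{k_1}} - e^{\lambda_{q_2} \cdot \lambda_{k_2}} + \lambda_{\mathrm{init}}$$

Each head's output is then passed through the sublayer RMSNorm (subln) and scaled by (1 - lambda_init), as in the code above.
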
@@ -368,20 +350,21 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
        # Calculate attention using SDPA
        is_causal = attention_mask is None and q_len > 1

        attn_output1 = F.scaled_dot_product_attention(
        dropout_p = self.attention_dropout if self.training else 0.0
        attn1 = F.scaled_dot_product_attention(
            q1,
            k1,
            v,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        attn_output2 = F.scaled_dot_product_attention(
        attn2 = F.scaled_dot_product_attention(
            q2,
            k2,
            v,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

@@ -395,9 +378,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn_output1 - lambda_full * attn_output2
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
@@ -411,3 +395,157 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
                past_key_value,
            ) # Note: can't return attn_weights with SDPA
        return attn, None, past_key_value


class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
    """Differential Attention implementation using Flash Attention 2.
    This implements the same logic as `LlamaDifferentialAttention`, but uses
    Flash Attention 2 for more efficient computation.

    This implements a modified attention mechanism that computes the difference between
    two attention patterns, scaled by learned lambda parameters. The mechanism helps
    reduce noise in the attention weights for irrelevant / less relevant tokens.

    Key components:
    - Split head dimension for differential computation
    - Learned lambda parameters that control attention scaling
    - Sublayer normalization on the attention output
    - Flash Attention 2 for efficient attention computation

    See:
    - https://arxiv.org/abs/2410.05258
    - https://github.com/microsoft/unilm/tree/master/Diff-Transformer

    Args:
        config: Model configuration object containing hidden size, number of heads etc.
        layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
    """

    def forward(
        self,
        hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[tuple[torch.Tensor, torch.Tensor]],
    ]:
        if output_attentions:
            transformers.logger.warning_once(
                "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. "
                "Falling back to the manual attention implementation."
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to Q1,Q2 and K1,K2
        qp = self.q_proj(hidden_states)
        kp = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split into Q1,Q2 and K1,K2
        q1, q2 = qp.chunk(2, dim=-1)
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
        # Reshape K1,K2 for attention
        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
        # Reshape V
        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            if position_ids is None:
                position_ids = torch.arange(q_len, device=q1.device)
            cos, sin = self.rotary_emb(q1, position_ids)
        else:
            cos, sin = position_embeddings

        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            k = torch.stack([k1, k2], dim=1)
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
            k1, k2 = k.unbind(dim=1)

        # Repeat KV heads to match Q heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        q1 = q1.transpose(1, 2)
        q2 = q2.transpose(1, 2)
        k1 = k1.transpose(1, 2)
        k2 = k2.transpose(1, 2)
        v = v.transpose(1, 2)

        # Calculate attention using Flash Attention
        dropout_p = self.attention_dropout if self.training else 0.0
        attn1 = flash_attn_func(
            q1,
            k1,
            v,
            dropout_p=dropout_p,
            causal=True,
        )
        attn2 = flash_attn_func(
            q2,
            k2,
            v,
            dropout_p=dropout_p,
            causal=True,
        )

        attn1 = attn1.transpose(1, 2)
        attn2 = attn2.transpose(1, 2)

        # Calculate lambda
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q1)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine the attention outputs
        attn = attn1 - lambda_full * attn2

        # Apply sublayer norm and scaling
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        # Reshape to output
        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn = self.o_proj(attn)

        if output_attentions:
            return (
                attn,
                None,
                past_key_value,
            ) # Note: can't return attn_weights with Flash Attention
        return attn, None, past_key_value
@@ -3,8 +3,9 @@
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

from axolotl.integrations.diff_transformer.multihead_diffattn import (
from axolotl.integrations.differential_transformer.differential_attention import (
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)

@@ -15,6 +16,9 @@ def patch_llama_attention_classes():
    # Add our attention class to the registry
    LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
    LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
    LLAMA_ATTENTION_CLASSES[
        "differential_flash_attention_2"
    ] = LlamaDifferentialFlashAttention2

    @classmethod
    def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
@@ -28,6 +32,7 @@ def patch_llama_attention_classes():
            "flash_attention_2",
            "differential_eager",
            "differential_sdpa",
            "differential_flash_attention_2",
        ]
        if attn_implementation not in valid_impls:
            message = (

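With the registry entries and the extended valid_impls list above, a patched model should be able to request the new kernel by name. A rough usage sketch, assuming the patch function is exposed by the integration's patches module (the import path and checkpoint path are illustrative, not taken from this diff):

from transformers import AutoModelForCausalLM

from axolotl.integrations.differential_transformer.patches import patch_llama_attention_classes

patch_llama_attention_classes()
model = AutoModelForCausalLM.from_pretrained(
    "path/to/converted-checkpoint",  # placeholder: a checkpoint converted with convert_to_diff_attention
    attn_implementation="differential_flash_attention_2",
)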