differential flash attention 2; cleanup

This commit is contained in:
Dan Saunders
2024-12-17 18:44:47 +00:00
parent 594c42f169
commit 6425d052bc
8 changed files with 268 additions and 106 deletions

View File

@@ -0,0 +1,6 @@
metric,training,validation
loss,1.8773103952407837,1.915901780128479
model_preparation_time,0.0051,0.0051
runtime,89.7635,8.9565
samples_per_second,20.053,22.33
steps_per_second,20.053,22.33

View File

@@ -14,7 +14,9 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
from axolotl.integrations.differential_transformer.convert import (
convert_to_diff_attention,
)
LOG = logging.getLogger("axolotl.cli.convert_attention")
@@ -74,7 +76,11 @@ def convert_diff_transformer(cfg, cli_args, config_path):
# Convert attention
LOG.info("Converting to differential attention...")
try:
model = convert_to_diff_attention(model, cli_args.zero_init)
model = convert_to_diff_attention(
model=model,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
@@ -130,43 +136,35 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET
)
else:
if cli_args.zero_init:
LOG.info(
Fore.RED
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
)
if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
else:
LOG.info(
Fore.YELLOW
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ "However, this is expected since --zero-init was not passed."
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
return model
except Exception as exc:
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
raise
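
For reference, a hedged sketch of driving this conversion programmatically with the CLI arguments above (the YAML path is a placeholder, the load_cfg usage is an assumption, and convert_diff_transformer is the function defined in this file):

from axolotl.cli import load_cfg
from axolotl.common.cli import ConvertDiffTransformerCliArgs

config_path = "path/to/config.yaml"  # placeholder path
cfg = load_cfg(config_path)  # assumed usage of axolotl's config loader
cli_args = ConvertDiffTransformerCliArgs(debug=True, zero_init=True, sublayer_norm=False)
# With zero_init=True and sublayer_norm=False, the original and converted generations
# are expected to match exactly (see the messaging logic above).
model = convert_diff_transformer(cfg, cli_args, config_path)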

View File

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
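
As a minimal illustration (hypothetical DummyConfig model and command, not part of this commit; assumes add_options_from_config is importable from this module), the updated decorator now unwraps Optional[bool] fields and emits paired flags:

from typing import Optional

import click
from pydantic import BaseModel

class DummyConfig(BaseModel):
    sublayer_norm: Optional[bool] = None  # Optional unwrapped to bool by the decorator

@click.command()
@add_options_from_config(DummyConfig)  # adds --sublayer-norm/--no-sublayer-norm
def main(**kwargs):
    # kwargs["sublayer_norm"] is True, False, or None; None defers to the pydantic default
    click.echo(kwargs)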

View File

@@ -62,6 +62,7 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
def load_model_and_tokenizer(

View File

@@ -5,22 +5,35 @@ from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)
from .multihead_diffattn import (
from .differential_attention import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}
def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention],
new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
new_attn: Union[
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
LlamaDifferentialFlashAttention2,
],
zero_init: bool = False,
) -> None:
"""
@@ -69,31 +82,24 @@ def copy_attention_weights(
def convert_to_diff_attention(
model: PreTrainedModel, zero_init: bool
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
attention_patterns = (
LlamaAttention,
LlamaSdpaAttention,
MistralAttention,
MixtralAttention,
)
layer_idx = 0
# Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm
def convert_module(module):
nonlocal layer_idx
# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
if isinstance(child, attention_patterns):
layer_type = type(child).__name__
if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
# Choose appropriate differential attention class
if isinstance(child, LlamaSdpaAttention):
attention_class = LlamaDifferentialSdpaAttention
else:
attention_class = LlamaDifferentialAttention
attention_class = ATTENTION_MAPPING[type(child)]
layer_type = type(child).__name__
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
)
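
A hedged usage sketch of the new signature (placeholder checkpoint path; not taken from this commit):

from transformers import AutoModelForCausalLM

from axolotl.integrations.differential_transformer.convert import convert_to_diff_attention

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder
# sublayer_norm is stored on model.config so the new attention modules can pick it up;
# zero_init controls how copy_attention_weights initializes the second attention branch.
model = convert_to_diff_attention(model, zero_init=True, sublayer_norm=False)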

View File

@@ -7,9 +7,11 @@ from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
@@ -75,14 +77,11 @@ class LlamaDifferentialAttention(nn.Module):
self.rope_theta = config.rope_theta
self.is_causal = True
dtype = torch.float32
# For Q1 and Q2
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size * 2,
bias=False,
dtype=dtype,
)
# For K1 and K2
@@ -90,7 +89,6 @@ class LlamaDifferentialAttention(nn.Module):
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
dtype=dtype,
)
# Single V projection
@@ -98,7 +96,6 @@ class LlamaDifferentialAttention(nn.Module):
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
dtype=dtype,
)
# Output projection
@@ -106,28 +103,33 @@ class LlamaDifferentialAttention(nn.Module):
self.hidden_size,
self.hidden_size,
bias=False,
dtype=dtype,
)
# Initialize differential attention parameters
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
def forward(
self,
@@ -192,39 +194,21 @@ class LlamaDifferentialAttention(nn.Module):
# Calculate attention scores for both parts
# NOTE(Dan): the Differential Transformer paper scales by a constant scaling factor
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
# Add this debug step right after computing attention weights in the forward pass
attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn_weights1 = attn_weights1 + causal_mask
attn_weights2 = attn_weights2 + causal_mask
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
# Apply softmax separately as per paper
attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
attn_weights1
)
attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
attn_weights2
)
attn_weights1 = F.dropout(
attn_weights1, p=self.attention_dropout, training=self.training
)
attn_weights2 = F.dropout(
attn_weights2, p=self.attention_dropout, training=self.training
)
# Apply softmax
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
# Apply dropout
attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)
# Calculate lambda
lambda_1 = torch.exp(
@@ -236,15 +220,13 @@ class LlamaDifferentialAttention(nn.Module):
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Compute differential attention (following paper's formula)
attn_weights = attn_weights1 - lambda_full * attn_weights2
attn_weights = attn1 - lambda_full * attn2
# Apply attention weights to values
attn = torch.matmul(attn_weights, v)
# Apply sublayer norm and scaling
# NOTE(Dan): The differential transformers paper applies sublayer normalization at this
# point, but this is typically done outside of the attention layer. It would look something
# like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
@@ -368,20 +350,21 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
# Calculate attention using SDPA
is_causal = attention_mask is None and q_len > 1
attn_output1 = F.scaled_dot_product_attention(
dropout_p = self.attention_dropout if self.training else 0.0
attn1 = F.scaled_dot_product_attention(
q1,
k1,
v,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
dropout_p=dropout_p,
is_causal=is_causal,
)
attn_output2 = F.scaled_dot_product_attention(
attn2 = F.scaled_dot_product_attention(
q2,
k2,
v,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
dropout_p=dropout_p,
is_causal=is_causal,
)
@@ -395,9 +378,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Combine the attention outputs
attn = attn_output1 - lambda_full * attn_output2
attn = attn1 - lambda_full * attn2
# Apply sublayer norm and scaling
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
@@ -411,3 +395,157 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
past_key_value,
) # Note: can't return attn_weights with SDPA
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
"""Differential Attention implementation using Flash Attention 2.
This implements the same logic as `LlamaDifferentialAttention`, but uses
Flash Attention 2 for more efficient computation.
The mechanism computes the difference between two attention patterns, scaled by
learned lambda parameters, which helps reduce noise in the attention weights for
irrelevant or less relevant tokens.
Key components:
- Split head dimension for differential computation
- Learned lambda parameters that control attention scaling
- Sublayer normalization on the attention output
- Flash Attention 2 for efficient attention computation
See:
- https://arxiv.org/abs/2410.05258
- https://github.com/microsoft/unilm/tree/master/Diff-Transformer
Args:
config: Model configuration object containing hidden size, number of heads etc.
layer_idx: Index of this layer in the transformer stack
dtype: Data type for the layer parameters
"""
def forward(
self,
hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]],
]:
if output_attentions:
transformers.logging.get_logger(__name__).warning_once(
"LlamaModel is using LlamaDifferentialFlashAttention2, but Flash Attention does not support `output_attentions=True`. "
"Falling back to the manual attention implementation."
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
# Project to Q1,Q2 and K1,K2
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Split into Q1,Q2 and K1,K2
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
# Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
# Apply rotary embeddings
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q_len, device=q1.device).unsqueeze(0)  # add batch dim for rotary embedding
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match Q heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
q1 = q1.transpose(1, 2)
q2 = q2.transpose(1, 2)
k1 = k1.transpose(1, 2)
k2 = k2.transpose(1, 2)
v = v.transpose(1, 2)
# Calculate attention using Flash Attention
dropout_p = self.attention_dropout if self.training else 0.0
attn1 = flash_attn_func(
q1,
k1,
v,
dropout_p=dropout_p,
causal=True,
)
attn2 = flash_attn_func(
q2,
k2,
v,
dropout_p=dropout_p,
causal=True,
)
attn1 = attn1.transpose(1, 2)
attn2 = attn2.transpose(1, 2)
# Calculate lambda
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# Combine the attention outputs
attn = attn1 - lambda_full * attn2
# Apply sublayer norm and scaling
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# Reshape to output
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = self.o_proj(attn)
if output_attentions:
return (
attn,
None,
past_key_value,
) # Note: can't return attn_weights with Flash Attention
return attn, None, past_key_value
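
For reference, a standalone sketch of the lambda re-parameterization shared by all three attention classes above (the closed form of lambda_init_fn is an assumption based on the Differential Transformer paper and is not shown in this diff):

import math

import torch

head_dim, layer_idx = 64, 0
lambda_q1, lambda_k1, lambda_q2, lambda_k2 = (
    torch.zeros(head_dim).normal_(mean=0, std=0.1) for _ in range(4)
)
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_idx)  # assumed lambda_init_fn, per the paper
lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1))
lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2))
lambda_full = lambda_1 - lambda_2 + lambda_init
# The two attention branches are then combined as attn1 - lambda_full * attn2, followed by
# the optional sublayer norm and scaling by (1 - lambda_init).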

View File

@@ -3,8 +3,9 @@
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.diff_transformer.multihead_diffattn import (
from axolotl.integrations.differential_transformer.differential_attention import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
@@ -15,6 +16,9 @@ def patch_llama_attention_classes():
# Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2
@classmethod
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
@@ -28,6 +32,7 @@ def patch_llama_attention_classes():
"flash_attention_2",
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (