xformers attention with packing (#2619)

* xformers attention with packing

* wire up the patch

* fix xformers + packing validation

* fix warning

* reorder the packing check

* fix fp16 / bf16 reset when using fp16 with bf16 auto

* fix seq lens calc to drop hanging sequences

* handle xformers patch for inference too

* fix batch size setter

* fix xformers inference

* add colab callback to fix inference post train

* PR feedback
This commit is contained in:
Wing Lian
2025-05-06 22:49:22 -04:00
parent fcbd7477d0
commit cae5cebb59
8 changed files with 222 additions and 12 deletions

View File

@@ -73,11 +73,12 @@ load_in_8bit: true
load_in_4bit: load_in_4bit:
# Use CUDA bf16 # Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16 # Use CUDA fp16
fp16: true fp16: true
# Use CUDA tf32 # Use CUDA tf32
tf32: true # require >=ampere tf32: true # require >=ampere
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
# No AMP (automatic mixed precision) # No AMP (automatic mixed precision)
bfloat16: true # require >=ampere bfloat16: true # require >=ampere

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect import inspect
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer)) callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer)) callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks return callbacks

View File

@@ -0,0 +1,19 @@
"""
attention module for attention monkeypatches
"""
from transformers.integrations.flash_attention import flash_attention_forward
def patch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .xformers import xformers_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()

View File

@@ -0,0 +1,160 @@
"""
xformers attention implementation for packing
"""
from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
def xformers_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
dropout: float = 0.0, # pylint: disable=unused-argument
scaling: Optional[float] = None, # pylint: disable=unused-argument
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None, # pylint: disable=unused-argument
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
if cu_seq_lens_q is None or cu_seq_lens_k is None:
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
elif attention_mask is not None:
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)
return attn_output, None

View File

@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
): ):
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
class ColabCallback(TrainerCallback):
"""Callback to prep model for inference on Google Colab"""
def __init__(self, cfg):
self.gpu_name = torch.cuda.get_device_name(0)
self.cfg = cfg
def on_train_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True
trainer.model.eval()
return ColabCallback

View File

@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
if cfg.fp16 is None and not cfg.float16: if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True cfg.fp16 = True
if cfg.fp16 and cfg.bf16 == "auto":
cfg.bf16 = False
if cfg.device == "mps": if cfg.device == "mps":
cfg.load_in_8bit = False cfg.load_in_8bit = False
cfg.tf32 = False cfg.tf32 = False

View File

@@ -540,6 +540,11 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None: def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils

View File

@@ -435,16 +435,6 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -471,9 +461,10 @@ class AxolotlInputConfig(
and not data.get("flash_attention") and not data.get("flash_attention")
and not data.get("sdp_attention") and not data.get("sdp_attention")
and not data.get("flex_attention") and not data.get("flex_attention")
and not data.get("xformers_attention")
): ):
LOG.warning( LOG.warning(
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination." "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
) )
return data return data