xformers attention with packing (#2619)
* xformers attention with packing * wire up the patch * fix xformers + packing validation * fix warning * reorder the packing check * fix fp16 / bf16 reset when using fp16 with bf16 auto * fix seq lens calc to drop hanging sequences * handle xformers patch for inference too * fix batch size setter * fix xformers inference * add colab callback to fix inference post train * PR feedback
This commit is contained in:
@@ -73,11 +73,12 @@ load_in_8bit: true
|
|||||||
load_in_4bit:
|
load_in_4bit:
|
||||||
|
|
||||||
# Use CUDA bf16
|
# Use CUDA bf16
|
||||||
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
|
||||||
# Use CUDA fp16
|
# Use CUDA fp16
|
||||||
fp16: true
|
fp16: true
|
||||||
# Use CUDA tf32
|
# Use CUDA tf32
|
||||||
tf32: true # require >=ampere
|
tf32: true # require >=ampere
|
||||||
|
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
|
||||||
|
|
||||||
# No AMP (automatic mixed precision)
|
# No AMP (automatic mixed precision)
|
||||||
bfloat16: true # require >=ampere
|
bfloat16: true # require >=ampere
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import importlib.util
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
|
colab_inference_post_train_callback,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||||
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
callbacks.append(lisa_callback_factory(trainer))
|
callbacks.append(lisa_callback_factory(trainer))
|
||||||
|
|
||||||
|
if any("COLAB_" in key for key in os.environ):
|
||||||
|
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||||
|
callbacks.append(ColabCallback(self.cfg))
|
||||||
|
|
||||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
attention module for attention monkeypatches
|
||||||
|
"""
|
||||||
|
|
||||||
|
from transformers.integrations.flash_attention import flash_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
def patch_xformers_attn_over_fa2():
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
from .xformers import xformers_attention_forward
|
||||||
|
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
def unpatch_xformers_attn_over_fa2():
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()
|
||||||
|
|||||||
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
xformers attention implementation for packing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import xformers
|
||||||
|
import xformers.ops.fmha
|
||||||
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
_upad_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_attention_forward(
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
dropout: float = 0.0, # pylint: disable=unused-argument
|
||||||
|
scaling: Optional[float] = None, # pylint: disable=unused-argument
|
||||||
|
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
||||||
|
softcap: Optional[float] = None, # pylint: disable=unused-argument
|
||||||
|
cu_seq_lens_q: Optional[torch.LongTensor] = None,
|
||||||
|
cu_seq_lens_k: Optional[torch.LongTensor] = None,
|
||||||
|
max_length_q: Optional[int] = None,
|
||||||
|
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
# Get dimensions
|
||||||
|
# query: [batch, heads, seq_len, hidden_dim]
|
||||||
|
batch_size = query.size(0)
|
||||||
|
query_length = query.shape[2]
|
||||||
|
key_length = key.shape[2]
|
||||||
|
|
||||||
|
# Default causal mask
|
||||||
|
attn_bias = xformers.ops.LowerTriangularMask()
|
||||||
|
|
||||||
|
# Check if we have sliding window attention
|
||||||
|
has_sliding_window = sliding_window is not None and sliding_window < query_length
|
||||||
|
|
||||||
|
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
# Get GQA parameters
|
||||||
|
num_attention_heads = module.config.num_attention_heads
|
||||||
|
num_key_value_heads = module.config.num_key_value_heads
|
||||||
|
head_dim = query.size(-1)
|
||||||
|
is_gqa = num_attention_heads != num_key_value_heads
|
||||||
|
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
|
||||||
|
|
||||||
|
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||||
|
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||||
|
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||||
|
if position_ids is not None and (
|
||||||
|
max_length_q is not None
|
||||||
|
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||||
|
):
|
||||||
|
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||||
|
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||||
|
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||||
|
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
|
||||||
|
attn_bias = (
|
||||||
|
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen=seq_lengths.tolist(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
|
# Handle GQA
|
||||||
|
if is_gqa:
|
||||||
|
key = key.repeat_interleave(n_groups, dim=2)
|
||||||
|
value = value.repeat_interleave(n_groups, dim=2)
|
||||||
|
|
||||||
|
elif attention_mask is not None:
|
||||||
|
query, key, value, _, cu_seq_lens, _ = _upad_input(
|
||||||
|
query, key, value, attention_mask, query_length
|
||||||
|
)
|
||||||
|
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||||
|
seq_lengths = []
|
||||||
|
for i in range(len(cu_seq_lens_q) - 1):
|
||||||
|
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
|
||||||
|
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen=seq_lengths,
|
||||||
|
kv_seqlen=seq_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle GQA
|
||||||
|
if is_gqa:
|
||||||
|
key = key.repeat_interleave(n_groups, dim=2)
|
||||||
|
value = value.repeat_interleave(n_groups, dim=2)
|
||||||
|
else:
|
||||||
|
# Handle Group Query Attention (GQA) using view/expand approach from reference
|
||||||
|
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||||
|
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||||
|
key = key.expand(
|
||||||
|
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
value = value.expand(
|
||||||
|
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
if module.training:
|
||||||
|
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||||
|
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||||
|
|
||||||
|
if has_sliding_window:
|
||||||
|
query = query.view(
|
||||||
|
1, batch_size * query_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
key = key.view(
|
||||||
|
1, batch_size * key_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
value = value.view(
|
||||||
|
1, batch_size * key_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = query.view(
|
||||||
|
batch_size, query_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we need a sliding window attention
|
||||||
|
if has_sliding_window:
|
||||||
|
query = query.view(
|
||||||
|
1,
|
||||||
|
batch_size * query_length,
|
||||||
|
num_key_value_heads,
|
||||||
|
n_groups,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
key = key.view(
|
||||||
|
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
value = value.view(
|
||||||
|
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the xformers attention
|
||||||
|
attn_output = xformers_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_bias=attn_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(
|
||||||
|
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
|
||||||
|
)
|
||||||
|
return attn_output, None
|
||||||
@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def colab_inference_post_train_callback(trainer: Trainer):
|
||||||
|
class ColabCallback(TrainerCallback):
|
||||||
|
"""Callback to prep model for inference on Google Colab"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.gpu_name = torch.cuda.get_device_name(0)
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def on_train_end(
|
||||||
|
self, args, state, control, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
handle T4 gpu, we need to convert attention to eager for inference
|
||||||
|
"""
|
||||||
|
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
|
||||||
|
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"eager"
|
||||||
|
)
|
||||||
|
trainer.model.gradient_checkpointing_disable()
|
||||||
|
trainer.model.config.use_cache = True
|
||||||
|
trainer.model.eval()
|
||||||
|
|
||||||
|
return ColabCallback
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
|
|||||||
if cfg.fp16 is None and not cfg.float16:
|
if cfg.fp16 is None and not cfg.float16:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
if cfg.fp16 and cfg.bf16 == "auto":
|
||||||
|
cfg.bf16 = False
|
||||||
|
|
||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
cfg.load_in_8bit = False
|
cfg.load_in_8bit = False
|
||||||
cfg.tf32 = False
|
cfg.tf32 = False
|
||||||
|
|||||||
@@ -540,6 +540,11 @@ class ModelLoader:
|
|||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
|
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||||
|
|
||||||
|
patch_xformers_attn_over_fa2()
|
||||||
|
self.cfg.flash_attention = True
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||||
|
|
||||||
|
|||||||
@@ -435,16 +435,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_sample_packing_w_xformers(cls, data):
|
|
||||||
if data.get("sample_packing") and data.get("xformers_attention"):
|
|
||||||
raise ValueError(
|
|
||||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -471,9 +461,10 @@ class AxolotlInputConfig(
|
|||||||
and not data.get("flash_attention")
|
and not data.get("flash_attention")
|
||||||
and not data.get("sdp_attention")
|
and not data.get("sdp_attention")
|
||||||
and not data.get("flex_attention")
|
and not data.get("flex_attention")
|
||||||
|
and not data.get("xformers_attention")
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
|
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
Reference in New Issue
Block a user