feat: add sageattention (#2823) [skip ci]
* feat: add sageattention * feat: call path on pre model load * fix: patch to use register to correct var * fix: add strict check import at start * chore: fix comments * chore: refactor * feat: add capability check * fix: missed underscore * fix: let sageattention use FA backend in transformers * feat: update sage attention for attention mask and position ids * feat: allow sample packing but add warning without packing * fix: loss hitting 0 with packing and attention mask note * feat: downcast embeds if sage attention too * feat: add config validation * feat: add attention docs * chore: docs
This commit is contained in:
@@ -320,6 +320,7 @@ website:
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
- docs/attention.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
|
||||
140
docs/attention.qmd
Normal file
140
docs/attention.qmd
Normal file
@@ -0,0 +1,140 @@
|
||||
---
|
||||
title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention 2
|
||||
|
||||
Uses efficient kernels to compute attention.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Nvidia
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
Note: For Turing GPUs or lower, please use other attention methods.
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
#### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
No sample packing support!
|
||||
|
||||
:::
|
||||
@@ -338,7 +338,12 @@ class ModelLoader:
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
(
|
||||
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
||||
(
|
||||
needs_fa2_dtype
|
||||
or self.cfg.flash_attention
|
||||
or self.cfg.flex_attention
|
||||
or self.cfg.sage_attention
|
||||
)
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
)
|
||||
or (
|
||||
@@ -612,6 +617,10 @@ class ModelLoader:
|
||||
elif self.cfg.sdp_attention:
|
||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||
self.model_config._attn_implementation = "sdpa"
|
||||
elif self.cfg.sage_attention:
|
||||
# sets FA2 attention to re-use same internal handling like masking
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
elif self.cfg.eager_attention:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = "eager"
|
||||
|
||||
@@ -96,6 +96,7 @@ class PatchManager:
|
||||
# self._apply_flex_attention_patches()
|
||||
self._apply_flash_attention_patches()
|
||||
self._apply_chunked_cross_entropy_patch()
|
||||
self._apply_sageattn_patches()
|
||||
self._apply_fsdp_patches()
|
||||
self._apply_adapter_patches()
|
||||
self._apply_model_specific_patches()
|
||||
@@ -201,6 +202,13 @@ class PatchManager:
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
|
||||
def _apply_sageattn_patches(self):
|
||||
"""Apply patches for SageAttention."""
|
||||
if self.cfg.sage_attention:
|
||||
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
|
||||
|
||||
patch_sageattn()
|
||||
|
||||
def _apply_model_specific_patches(self):
|
||||
"""Apply patches specific to model architectures."""
|
||||
if (
|
||||
|
||||
211
src/axolotl/monkeypatch/attention/sage_attn.py
Normal file
211
src/axolotl/monkeypatch/attention/sage_attn.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Monkeypatch for SageAttention for use with transformers.
|
||||
|
||||
https://github.com/thu-ml/SageAttention/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers.integrations.sdpa_attention import repeat_kv
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
sageattn = None # pylint: disable=invalid-name
|
||||
sageattn_varlen = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_sageattn_available():
|
||||
"""Determine if SageAttention is available"""
|
||||
try:
|
||||
import sageattention # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
if _is_sageattn_available():
|
||||
# import sageattn here if available
|
||||
from sageattention import sageattn, sageattn_varlen
|
||||
|
||||
|
||||
def _check_sageattn_imported():
|
||||
"""Check if SageAttention is imported. Raises an ImportError if not."""
|
||||
if sageattn is None:
|
||||
raise ImportError(
|
||||
"SageAttention is not installed. Please install it from source: "
|
||||
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
|
||||
)
|
||||
|
||||
|
||||
def sage_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
dropout: float = 0.0,
|
||||
scaling: float | None = None,
|
||||
is_causal: bool | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""
|
||||
Forward pass for SageAttention compatible with transformers attention interfaces.
|
||||
|
||||
https://github.com/thu-ml/SageAttention/
|
||||
"""
|
||||
|
||||
_check_sageattn_imported()
|
||||
|
||||
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
|
||||
raise NotImplementedError(
|
||||
"SageAttention does not support `output_attentions=True` or `head_mask`."
|
||||
)
|
||||
|
||||
# The base sageattn API does not support dropout.
|
||||
if dropout > 0.0:
|
||||
raise NotImplementedError("SageAttention does not support dropout.")
|
||||
|
||||
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
|
||||
if hasattr(module, "num_key_value_groups"):
|
||||
key = repeat_kv(key, module.num_key_value_groups)
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
# Calculate is_causal following transformers
|
||||
assert is_causal is not False, "is_causal must be True or None"
|
||||
is_causal = True
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
query_length = query.shape[2]
|
||||
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
|
||||
max_length_q = kwargs.get("max_length_q", None)
|
||||
max_length_k = kwargs.get("max_length_k", None)
|
||||
|
||||
# Sample packing uses position_ids, so we check for it first
|
||||
if position_ids is not None and (
|
||||
max_length_q is not None
|
||||
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||
):
|
||||
# transpose inputs to NHD layout for use with FA2 utils
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
batch_size = query.size(0)
|
||||
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
prepare_fa2_from_position_ids,
|
||||
)
|
||||
|
||||
if cu_seqlens_q is None or cu_seqlens_k is None:
|
||||
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
|
||||
prepare_fa2_from_position_ids(query, key, value, position_ids)
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_length_q, max_length_k = max_seq_lens
|
||||
|
||||
else:
|
||||
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||
|
||||
attn_output_unpad = sageattn_varlen(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_length_q,
|
||||
max_seqlen_k=max_length_k,
|
||||
is_causal=is_causal,
|
||||
sm_scale=scaling,
|
||||
smooth_k=False, # reduces loss 0 / nan grad norms
|
||||
tensor_layout="NHD",
|
||||
)
|
||||
|
||||
attn_output = attn_output_unpad.view(
|
||||
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
|
||||
)
|
||||
|
||||
elif attention_mask is not None:
|
||||
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
|
||||
|
||||
assert attention_mask.ndim == 2, "Attention mask must be 2D"
|
||||
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_upad_input,
|
||||
)
|
||||
|
||||
# transpose inputs to NHD layout for use with FA2 utils
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
batch_size = query.shape[0]
|
||||
|
||||
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
||||
query, key, value, attention_mask, query_length
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_q, max_seqlen_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = sageattn_varlen(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
is_causal=is_causal,
|
||||
sm_scale=scaling,
|
||||
tensor_layout="NHD",
|
||||
)
|
||||
|
||||
from flash_attn.bert_padding import pad_input
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
# Use standard sageattn
|
||||
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
|
||||
# which corresponds to SageAttention's "HND" layout.
|
||||
attn_output = sageattn(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="HND",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scaling,
|
||||
)
|
||||
|
||||
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
|
||||
# Transformers expects (batch, seq_len, heads, head_dim) for the output
|
||||
# So we need to transpose dimensions 1 and 2
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def patch_sageattn():
|
||||
"""Patch SageAttention for use with transformers."""
|
||||
|
||||
_check_sageattn_imported()
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
# Replace flash attention with sage attention
|
||||
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
|
||||
|
||||
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
|
||||
# Register sage_attention with the global attention interface
|
||||
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
|
||||
|
||||
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
|
||||
|
||||
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
|
||||
|
||||
LOG.info("SageAttention patched successfully")
|
||||
@@ -609,6 +609,12 @@ class AxolotlInputConfig(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Whether to use bettertransformers"},
|
||||
)
|
||||
sage_attention: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
|
||||
},
|
||||
)
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sageattn_wo_sample_packing(cls, data):
|
||||
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
|
||||
if not data.get("pad_to_sequence_len", False):
|
||||
LOG.warning(
|
||||
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
||||
"This is because there has been signs that the loss explodes after a few steps."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sageattn_fft(cls, data):
|
||||
if (not data.get("adapter", False)) and data.get("sage_attention"):
|
||||
LOG.warning(
|
||||
"We found loss to drop to 0 with SageAttention full finetuning."
|
||||
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""Wrapper to valdiate GPU capabilities with the configured options"""
|
||||
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_compute_capability_w_sageattn(cls, data):
|
||||
if (
|
||||
data.get("sage_attention")
|
||||
and data.get("capabilities")
|
||||
and data.get("capabilities").get("compute_capability")
|
||||
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
||||
):
|
||||
raise ValueError(
|
||||
"SageAttention supports compute capability between sm_80 and sm_120. "
|
||||
"Please use a different attention implementation."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_multigpu_unsloth(cls, data):
|
||||
|
||||
@@ -166,9 +166,10 @@ class AttentionValidationMixin:
|
||||
fields = (
|
||||
"xformers_attention",
|
||||
"sdp_attention",
|
||||
"s2_attention",
|
||||
# "s2_attention", # requires both FA and this to be enabled
|
||||
"flash_attention",
|
||||
"flex_attention",
|
||||
"sage_attention",
|
||||
)
|
||||
non_empty_count = sum(1 for field in fields if data.get(field))
|
||||
|
||||
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
|
||||
and not data.get("sdp_attention")
|
||||
and not data.get("flex_attention")
|
||||
and not data.get("xformers_attention")
|
||||
and not data.get("sage_attention")
|
||||
):
|
||||
LOG.warning(
|
||||
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
|
||||
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user