feat: add sageattention (#2823) [skip ci]
* feat: add sageattention * feat: call path on pre model load * fix: patch to use register to correct var * fix: add strict check import at start * chore: fix comments * chore: refactor * feat: add capability check * fix: missed underscore * fix: let sageattention use FA backend in transformers * feat: update sage attention for attention mask and position ids * feat: allow sample packing but add warning without packing * fix: loss hitting 0 with packing and attention mask note * feat: downcast embeds if sage attention too * feat: add config validation * feat: add attention docs * chore: docs
This commit is contained in:
@@ -320,6 +320,7 @@ website:
|
|||||||
- docs/multipack.qmd
|
- docs/multipack.qmd
|
||||||
- docs/mixed_precision.qmd
|
- docs/mixed_precision.qmd
|
||||||
- docs/optimizers.qmd
|
- docs/optimizers.qmd
|
||||||
|
- docs/attention.qmd
|
||||||
|
|
||||||
- section: "Advanced Features"
|
- section: "Advanced Features"
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
140
docs/attention.qmd
Normal file
140
docs/attention.qmd
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
---
|
||||||
|
title: Attention
|
||||||
|
description: Supported attention modules in Axolotl
|
||||||
|
---
|
||||||
|
|
||||||
|
## SDP Attention
|
||||||
|
|
||||||
|
This is the default built-in attention in PyTorch.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sdp_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||||
|
|
||||||
|
## Flash Attention 2
|
||||||
|
|
||||||
|
Uses efficient kernels to compute attention.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
flash_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||||
|
|
||||||
|
### Nvidia
|
||||||
|
|
||||||
|
Requirements: Ampere, Ada, or Hopper GPUs
|
||||||
|
|
||||||
|
Note: For Turing GPUs or lower, please use other attention methods.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
|
||||||
|
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Flash Attention 3
|
||||||
|
|
||||||
|
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
|
cd flash-attention/hopper
|
||||||
|
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
### AMD
|
||||||
|
|
||||||
|
Requirements: ROCm 6.0 and above.
|
||||||
|
|
||||||
|
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||||
|
|
||||||
|
## Flex Attention
|
||||||
|
|
||||||
|
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
flex_attention: true
|
||||||
|
|
||||||
|
# recommended
|
||||||
|
torch_compile: true
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
We recommend using latest stable version of PyTorch for best performance.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||||
|
|
||||||
|
## SageAttention
|
||||||
|
|
||||||
|
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sage_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
Requirements: Ampere, Ada, or Hopper GPUs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install sageattention==2.2.0 --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
|
||||||
|
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
## xFormers
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
xformers_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
|
||||||
|
We recommend using with Turing GPUs or below (such as on Colab).
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||||
|
|
||||||
|
## Shifted Sparse Attention
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
|
||||||
|
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
Requirements: LLaMA model architecture
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
|
||||||
|
No sample packing support!
|
||||||
|
|
||||||
|
:::
|
||||||
@@ -338,7 +338,12 @@ class ModelLoader:
|
|||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
||||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
(
|
(
|
||||||
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
(
|
||||||
|
needs_fa2_dtype
|
||||||
|
or self.cfg.flash_attention
|
||||||
|
or self.cfg.flex_attention
|
||||||
|
or self.cfg.sage_attention
|
||||||
|
)
|
||||||
and not self.is_qlora_and_fsdp_enabled
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
@@ -612,6 +617,10 @@ class ModelLoader:
|
|||||||
elif self.cfg.sdp_attention:
|
elif self.cfg.sdp_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||||
self.model_config._attn_implementation = "sdpa"
|
self.model_config._attn_implementation = "sdpa"
|
||||||
|
elif self.cfg.sage_attention:
|
||||||
|
# sets FA2 attention to re-use same internal handling like masking
|
||||||
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
self.model_config._attn_implementation = "flash_attention_2"
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
self.model_kwargs["attn_implementation"] = "eager"
|
||||||
self.model_config._attn_implementation = "eager"
|
self.model_config._attn_implementation = "eager"
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ class PatchManager:
|
|||||||
# self._apply_flex_attention_patches()
|
# self._apply_flex_attention_patches()
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
self._apply_chunked_cross_entropy_patch()
|
self._apply_chunked_cross_entropy_patch()
|
||||||
|
self._apply_sageattn_patches()
|
||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
self._apply_model_specific_patches()
|
self._apply_model_specific_patches()
|
||||||
@@ -201,6 +202,13 @@ class PatchManager:
|
|||||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||||
|
|
||||||
|
def _apply_sageattn_patches(self):
|
||||||
|
"""Apply patches for SageAttention."""
|
||||||
|
if self.cfg.sage_attention:
|
||||||
|
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
|
||||||
|
|
||||||
|
patch_sageattn()
|
||||||
|
|
||||||
def _apply_model_specific_patches(self):
|
def _apply_model_specific_patches(self):
|
||||||
"""Apply patches specific to model architectures."""
|
"""Apply patches specific to model architectures."""
|
||||||
if (
|
if (
|
||||||
|
|||||||
211
src/axolotl/monkeypatch/attention/sage_attn.py
Normal file
211
src/axolotl/monkeypatch/attention/sage_attn.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Monkeypatch for SageAttention for use with transformers.
|
||||||
|
|
||||||
|
https://github.com/thu-ml/SageAttention/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.integrations.sdpa_attention import repeat_kv
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
sageattn = None # pylint: disable=invalid-name
|
||||||
|
sageattn_varlen = None # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_sageattn_available():
|
||||||
|
"""Determine if SageAttention is available"""
|
||||||
|
try:
|
||||||
|
import sageattention # noqa: F401 # pylint: disable=unused-import
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if _is_sageattn_available():
|
||||||
|
# import sageattn here if available
|
||||||
|
from sageattention import sageattn, sageattn_varlen
|
||||||
|
|
||||||
|
|
||||||
|
def _check_sageattn_imported():
|
||||||
|
"""Check if SageAttention is imported. Raises an ImportError if not."""
|
||||||
|
if sageattn is None:
|
||||||
|
raise ImportError(
|
||||||
|
"SageAttention is not installed. Please install it from source: "
|
||||||
|
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sage_attention_forward(
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scaling: float | None = None,
|
||||||
|
is_causal: bool | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, None]:
|
||||||
|
"""
|
||||||
|
Forward pass for SageAttention compatible with transformers attention interfaces.
|
||||||
|
|
||||||
|
https://github.com/thu-ml/SageAttention/
|
||||||
|
"""
|
||||||
|
|
||||||
|
_check_sageattn_imported()
|
||||||
|
|
||||||
|
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"SageAttention does not support `output_attentions=True` or `head_mask`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# The base sageattn API does not support dropout.
|
||||||
|
if dropout > 0.0:
|
||||||
|
raise NotImplementedError("SageAttention does not support dropout.")
|
||||||
|
|
||||||
|
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
|
||||||
|
if hasattr(module, "num_key_value_groups"):
|
||||||
|
key = repeat_kv(key, module.num_key_value_groups)
|
||||||
|
value = repeat_kv(value, module.num_key_value_groups)
|
||||||
|
|
||||||
|
# Calculate is_causal following transformers
|
||||||
|
assert is_causal is not False, "is_causal must be True or None"
|
||||||
|
is_causal = True
|
||||||
|
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
query_length = query.shape[2]
|
||||||
|
|
||||||
|
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||||
|
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
|
||||||
|
max_length_q = kwargs.get("max_length_q", None)
|
||||||
|
max_length_k = kwargs.get("max_length_k", None)
|
||||||
|
|
||||||
|
# Sample packing uses position_ids, so we check for it first
|
||||||
|
if position_ids is not None and (
|
||||||
|
max_length_q is not None
|
||||||
|
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||||
|
):
|
||||||
|
# transpose inputs to NHD layout for use with FA2 utils
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size = query.size(0)
|
||||||
|
|
||||||
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
prepare_fa2_from_position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cu_seqlens_q is None or cu_seqlens_k is None:
|
||||||
|
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
|
||||||
|
prepare_fa2_from_position_ids(query, key, value, position_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_length_q, max_length_k = max_seq_lens
|
||||||
|
|
||||||
|
else:
|
||||||
|
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
|
attn_output_unpad = sageattn_varlen(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_length_q,
|
||||||
|
max_seqlen_k=max_length_k,
|
||||||
|
is_causal=is_causal,
|
||||||
|
sm_scale=scaling,
|
||||||
|
smooth_k=False, # reduces loss 0 / nan grad norms
|
||||||
|
tensor_layout="NHD",
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output_unpad.view(
|
||||||
|
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif attention_mask is not None:
|
||||||
|
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
|
||||||
|
|
||||||
|
assert attention_mask.ndim == 2, "Attention mask must be 2D"
|
||||||
|
|
||||||
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
_upad_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
# transpose inputs to NHD layout for use with FA2 utils
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size = query.shape[0]
|
||||||
|
|
||||||
|
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
||||||
|
query, key, value, attention_mask, query_length
|
||||||
|
)
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_q, max_seqlen_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = sageattn_varlen(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
is_causal=is_causal,
|
||||||
|
sm_scale=scaling,
|
||||||
|
tensor_layout="NHD",
|
||||||
|
)
|
||||||
|
|
||||||
|
from flash_attn.bert_padding import pad_input
|
||||||
|
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
else:
|
||||||
|
# Use standard sageattn
|
||||||
|
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
|
||||||
|
# which corresponds to SageAttention's "HND" layout.
|
||||||
|
attn_output = sageattn(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
tensor_layout="HND",
|
||||||
|
is_causal=is_causal,
|
||||||
|
sm_scale=scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
|
||||||
|
# Transformers expects (batch, seq_len, heads, head_dim) for the output
|
||||||
|
# So we need to transpose dimensions 1 and 2
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_sageattn():
|
||||||
|
"""Patch SageAttention for use with transformers."""
|
||||||
|
|
||||||
|
_check_sageattn_imported()
|
||||||
|
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
# Replace flash attention with sage attention
|
||||||
|
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
|
||||||
|
|
||||||
|
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
|
||||||
|
# Register sage_attention with the global attention interface
|
||||||
|
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
|
||||||
|
|
||||||
|
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
|
||||||
|
|
||||||
|
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
|
||||||
|
|
||||||
|
LOG.info("SageAttention patched successfully")
|
||||||
@@ -609,6 +609,12 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Whether to use bettertransformers"},
|
json_schema_extra={"description": "Whether to use bettertransformers"},
|
||||||
)
|
)
|
||||||
|
sage_attention: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
eager_attention: bool | None = None
|
eager_attention: bool | None = None
|
||||||
|
|
||||||
@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_sageattn_wo_sample_packing(cls, data):
|
||||||
|
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
|
||||||
|
if not data.get("pad_to_sequence_len", False):
|
||||||
|
LOG.warning(
|
||||||
|
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
||||||
|
"This is because there has been signs that the loss explodes after a few steps."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_sageattn_fft(cls, data):
|
||||||
|
if (not data.get("adapter", False)) and data.get("sage_attention"):
|
||||||
|
LOG.warning(
|
||||||
|
"We found loss to drop to 0 with SageAttention full finetuning."
|
||||||
|
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""Wrapper to valdiate GPU capabilities with the configured options"""
|
"""Wrapper to valdiate GPU capabilities with the configured options"""
|
||||||
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_compute_capability_w_sageattn(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("sage_attention")
|
||||||
|
and data.get("capabilities")
|
||||||
|
and data.get("capabilities").get("compute_capability")
|
||||||
|
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"SageAttention supports compute capability between sm_80 and sm_120. "
|
||||||
|
"Please use a different attention implementation."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_multigpu_unsloth(cls, data):
|
def check_multigpu_unsloth(cls, data):
|
||||||
|
|||||||
@@ -166,9 +166,10 @@ class AttentionValidationMixin:
|
|||||||
fields = (
|
fields = (
|
||||||
"xformers_attention",
|
"xformers_attention",
|
||||||
"sdp_attention",
|
"sdp_attention",
|
||||||
"s2_attention",
|
# "s2_attention", # requires both FA and this to be enabled
|
||||||
"flash_attention",
|
"flash_attention",
|
||||||
"flex_attention",
|
"flex_attention",
|
||||||
|
"sage_attention",
|
||||||
)
|
)
|
||||||
non_empty_count = sum(1 for field in fields if data.get(field))
|
non_empty_count = sum(1 for field in fields if data.get(field))
|
||||||
|
|
||||||
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
|
|||||||
and not data.get("sdp_attention")
|
and not data.get("sdp_attention")
|
||||||
and not data.get("flex_attention")
|
and not data.get("flex_attention")
|
||||||
and not data.get("xformers_attention")
|
and not data.get("xformers_attention")
|
||||||
|
and not data.get("sage_attention")
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
|
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user