diff --git a/_quarto.yml b/_quarto.yml
index fe3a76e53..7de2be6a7 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -320,6 +320,7 @@ website:
           - docs/multipack.qmd
           - docs/mixed_precision.qmd
           - docs/optimizers.qmd
+          - docs/attention.qmd
       - section: "Advanced Features"
         contents:
diff --git a/docs/attention.qmd b/docs/attention.qmd
new file mode 100644
index 000000000..21004277e
--- /dev/null
+++ b/docs/attention.qmd
@@ -0,0 +1,140 @@
+---
+title: Attention
+description: Supported attention modules in Axolotl
+---
+
+## SDP Attention
+
+This is PyTorch's built-in scaled dot-product attention.
+
+```yaml
+sdp_attention: true
+```
+
+For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+
+## Flash Attention 2
+
+Uses efficient kernels to compute attention.
+
+```yaml
+flash_attention: true
+```
+
+For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
+
+### Nvidia
+
+Requirements: Ampere, Ada, or Hopper GPUs
+
+Note: For Turing or older GPUs, please use one of the other attention methods.
+
+```bash
+pip install flash-attn --no-build-isolation
+```
+
+::: {.callout-tip}
+
+If you get an `undefined symbol` error while training, ensure you installed PyTorch before Axolotl. Alternatively, try reinstalling `flash-attn` or downgrading it by a version.
+
+:::
+
+#### Flash Attention 3
+
+Requirements: Hopper GPUs only; CUDA 12.8 recommended
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/hopper
+
+python setup.py install
+```
+
+### AMD
+
+Requirements: ROCm 6.0 and above.
+
+See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
+
+## Flex Attention
+
+A flexible PyTorch attention API, used in combination with `torch.compile`.
+
+```yaml
+flex_attention: true
+
+# recommended
+torch_compile: true
+```
+
+::: {.callout-note}
+
+We recommend using the latest stable version of PyTorch for best performance.
+
+:::
+
+For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
+
+## SageAttention
+
+Attention kernels using INT8 quantization for QK and an FP16 accumulator for PV.
+
+```yaml
+sage_attention: true
+```
+
+Requirements: Ampere, Ada, or Hopper GPUs
+
+```bash
+pip install sageattention==2.2.0 --no-build-isolation
+```
+
+::: {.callout-warning}
+
+Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 with full fine-tuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
+
+:::
+
+For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
+
+::: {.callout-note}
+
+We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention implementation, please open an Issue.
+
+:::
+
+
+## xFormers
+
+```yaml
+xformers_attention: true
+```
+
+::: {.callout-tip}
+
+We recommend this for Turing or older GPUs (such as on Colab).
+
+:::
+
+For more details: [xFormers](https://github.com/facebookresearch/xformers)
+
+## Shifted Sparse Attention
+
+::: {.callout-warning}
+
+We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.
+
+:::
+
+Requirements: LLaMA model architecture
+
+```yaml
+flash_attention: true
+s2_attention: true
+```
+
+::: {.callout-tip}
+
+No sample packing support!
+
+:::
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 75684c1ae..6c8885526 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -338,7 +338,12 @@ class ModelLoader:
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
             # we need to convert them back to fp16/bf16 for flash-attn compatibility.
             (
-                (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
+                (
+                    needs_fa2_dtype
+                    or self.cfg.flash_attention
+                    or self.cfg.flex_attention
+                    or self.cfg.sage_attention
+                )
                 and not self.is_qlora_and_fsdp_enabled
             )
             or (
@@ -612,6 +617,10 @@ class ModelLoader:
         elif self.cfg.sdp_attention:
             self.model_kwargs["attn_implementation"] = "sdpa"
             self.model_config._attn_implementation = "sdpa"
+        elif self.cfg.sage_attention:
+            # set the FA2 implementation so SageAttention reuses the same internal handling (e.g. masking)
+            self.model_kwargs["attn_implementation"] = "flash_attention_2"
+            self.model_config._attn_implementation = "flash_attention_2"
         elif self.cfg.eager_attention:
             self.model_kwargs["attn_implementation"] = "eager"
             self.model_config._attn_implementation = "eager"
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 30c3ba0fd..3cf8bbd20 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -96,6 +96,7 @@ class PatchManager:
         # self._apply_flex_attention_patches()
         self._apply_flash_attention_patches()
         self._apply_chunked_cross_entropy_patch()
+        self._apply_sageattn_patches()
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
         self._apply_model_specific_patches()
@@ -201,6 +202,13 @@ class PatchManager:
         flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
         patch_flex_wrapper(**flex_attn_compile_kwargs)

+    def _apply_sageattn_patches(self):
+        """Apply patches for SageAttention."""
+        if self.cfg.sage_attention:
+            from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
+
+            patch_sageattn()
+
     def _apply_model_specific_patches(self):
         """Apply patches specific to model architectures."""
         if (
diff --git a/src/axolotl/monkeypatch/attention/sage_attn.py b/src/axolotl/monkeypatch/attention/sage_attn.py
new file mode 100644
index 000000000..cc9fdb94d
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/sage_attn.py
@@ -0,0 +1,211 @@
+"""
+Monkeypatch for SageAttention for use with transformers.
+
+https://github.com/thu-ml/SageAttention/
+"""
+
+import torch
+from transformers.integrations.sdpa_attention import repeat_kv
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+sageattn = None  # pylint: disable=invalid-name
+sageattn_varlen = None  # pylint: disable=invalid-name
+
+
+def _is_sageattn_available():
+    """Determine if SageAttention is available."""
+    try:
+        import sageattention  # noqa: F401 # pylint: disable=unused-import
+
+        return True
+    except ImportError:
+        return False
+
+
+if _is_sageattn_available():
+    # import sageattn here if available
+    from sageattention import sageattn, sageattn_varlen
+
+
+def _check_sageattn_imported():
+    """Check if SageAttention is imported. Raises an ImportError if not."""
+    if sageattn is None:
+        raise ImportError(
+            "SageAttention is not installed. Please install it from source: "
+            "`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
+        )
+
+
+def sage_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None = None,
+    dropout: float = 0.0,
+    scaling: float | None = None,
+    is_causal: bool | None = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """
+    Forward pass for SageAttention, compatible with the transformers attention interface.
+
+    https://github.com/thu-ml/SageAttention/
+    """
+
+    _check_sageattn_imported()
+
+    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
+        raise NotImplementedError(
+            "SageAttention does not support `output_attentions=True` or `head_mask`."
+        )
+
+    # The base sageattn API does not support dropout.
+    if dropout > 0.0:
+        raise NotImplementedError("SageAttention does not support dropout.")
+
+    # Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
+    if hasattr(module, "num_key_value_groups"):
+        key = repeat_kv(key, module.num_key_value_groups)
+        value = repeat_kv(value, module.num_key_value_groups)
+
+    # Calculate is_causal following transformers
+    assert is_causal is not False, "is_causal must be True or None"
+    is_causal = True
+
+    position_ids = kwargs.get("position_ids", None)
+    query_length = query.shape[2]
+
+    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+    cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
+    max_length_q = kwargs.get("max_length_q", None)
+    max_length_k = kwargs.get("max_length_k", None)
+
+    # Sample packing uses position_ids, so we check for it first
+    if position_ids is not None and (
+        max_length_q is not None
+        or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    ):
+        # transpose inputs to NHD layout for use with FA2 utils
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        batch_size = query.size(0)
+
+        from transformers.modeling_flash_attention_utils import (
+            prepare_fa2_from_position_ids,
+        )
+
+        if cu_seqlens_q is None or cu_seqlens_k is None:
+            query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
+                prepare_fa2_from_position_ids(query, key, value, position_ids)
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_length_q, max_length_k = max_seq_lens
+
+        else:
+            query = query.reshape(-1, query.size(-2), query.size(-1))
+            key = key.reshape(-1, key.size(-2), key.size(-1))
+            value = value.reshape(-1, value.size(-2), value.size(-1))
+
+        attn_output_unpad = sageattn_varlen(
+            q=query,
+            k=key,
+            v=value,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
+            is_causal=is_causal,
+            sm_scale=scaling,
+            smooth_k=False,  # avoids the loss dropping to 0 / NaN grad norms
+            tensor_layout="NHD",
+        )
+
+        attn_output = attn_output_unpad.view(
+            batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
+        )
+
+    elif attention_mask is not None:
+        # NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
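+        # Padded-batch path: strip padding from Q/K/V with the FA2 helpers, run the
+        # varlen kernel over the packed sequences, then scatter the output back into
+        # (batch, seq_len, heads, head_dim) via `pad_input` below.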
+
+        assert attention_mask.ndim == 2, "Attention mask must be 2D"
+
+        from transformers.modeling_flash_attention_utils import (
+            _upad_input,
+        )
+
+        # transpose inputs to NHD layout for use with FA2 utils
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        batch_size = query.shape[0]
+
+        query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+            query, key, value, attention_mask, query_length
+        )
+        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+        max_seqlen_q, max_seqlen_k = max_seq_lens
+
+        attn_output_unpad = sageattn_varlen(
+            q=query,
+            k=key,
+            v=value,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            is_causal=is_causal,
+            sm_scale=scaling,
+            tensor_layout="NHD",
+        )
+
+        from flash_attn.bert_padding import pad_input
+
+        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+    else:
+        # Use standard sageattn.
+        # The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
+        # which corresponds to SageAttention's "HND" layout.
+        attn_output = sageattn(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="HND",
+            is_causal=is_causal,
+            sm_scale=scaling,
+        )
+
+        # SageAttention with "HND" returns (batch, heads, seq_len, head_dim).
+        # Transformers expects (batch, seq_len, heads, head_dim) for the output,
+        # so we need to transpose dimensions 1 and 2.
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, None
+
+
+def patch_sageattn():
+    """Patch SageAttention for use with transformers."""
+
+    _check_sageattn_imported()
+
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    # Replace flash attention with sage attention
+    ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
+
+    # Note: new method after the transformers refactor to ALL_MASK_ATTENTION_FUNCTIONS
+    # Register sage_attention with the global attention interface
+    # ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
+
+    # from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
+
+    # ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
+
+    LOG.info("SageAttention patched successfully")
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 653773273..d858fdbce 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -609,6 +609,12 @@ class AxolotlInputConfig(
         default=None,
         json_schema_extra={"description": "Whether to use bettertransformers"},
     )
+    sage_attention: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
+        },
+    )

     eager_attention: bool | None = None

@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
             )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_sageattn_wo_sample_packing(cls, data):
+        if (not data.get("sample_packing", False)) and data.get("sage_attention"):
+            if not data.get("pad_to_sequence_len", False):
+                LOG.warning(
+                    "We recommend turning on `pad_to_sequence_len` when using SageAttention without sample packing. "
+                    "There have been signs that the loss explodes after a few steps otherwise."
+                )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_sageattn_fft(cls, data):
+        if (not data.get("adapter", False)) and data.get("sage_attention"):
+            LOG.warning(
+                "We found the loss drops to 0 with SageAttention full fine-tuning. "
+                "Please monitor the loss, or switch to LoRA/QLoRA or another attention method."
+            )
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """Wrapper to valdiate GPU capabilities with the configured options"""
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_compute_capability_w_sageattn(cls, data):
+        if (
+            data.get("sage_attention")
+            and data.get("capabilities")
+            and data.get("capabilities").get("compute_capability")
+            not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
+        ):
+            raise ValueError(
+                "SageAttention requires a GPU with compute capability sm_80, sm_86, sm_89, sm_90, or sm_120. "
+                "Please use a different attention implementation."
+            )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_multigpu_unsloth(cls, data):
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 9f225b75e..bde367e0e 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -166,9 +166,10 @@ class AttentionValidationMixin:
         fields = (
             "xformers_attention",
             "sdp_attention",
-            "s2_attention",
+            # "s2_attention",  # requires both FA and this to be enabled
             "flash_attention",
             "flex_attention",
+            "sage_attention",
         )

         non_empty_count = sum(1 for field in fields if data.get(field))
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
             and not data.get("sdp_attention")
             and not data.get("flex_attention")
             and not data.get("xformers_attention")
+            and not data.get("sage_attention")
         ):
             LOG.warning(
-                "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
+                "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
             )

         return data