ci for fa3

This commit is contained in:
Wing Lian
2025-05-18 00:49:15 -07:00
parent fb5ef6d445
commit a064f1c9b4
6 changed files with 46 additions and 9 deletions

View File

@@ -629,6 +629,27 @@ class ModelLoader:
)
if self.cfg.flash_attention:
use_fa3 = False
if self.cfg.use_flash_attention_3 is True:
use_fa3 = True
elif self.cfg.use_flash_attention_3 == "auto":
if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
# FA3 is only available on Hopper (compute capability 90) and newer GPUs, so "auto" enables it only there
use_fa3 = True
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3
)
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
flash_attn_varlen_func_v3
)
LOG.info("Switched to Flash Attention v3")
self.patch_attention()
if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -699,6 +720,7 @@ class ModelLoader:
patch_mllama()
# TODO deprecate soon
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
@@ -706,6 +728,7 @@ class ModelLoader:
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
# TODO deprecate soon
if (
self.model_config.model_type == "stablelm_epoch"
and self.cfg.sample_packing

View File

@@ -233,6 +233,7 @@ class AxolotlInputConfig(
flash_attn_fuse_qkv: bool | None = None
flash_attn_fuse_mlp: bool | None = None
flash_optimum: bool | None = None
use_flash_attention_3: Literal["auto"] | bool | None = "auto"
eager_attention: bool | None = None