ci for fa3
This commit is contained in:
@@ -629,6 +629,27 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
use_fa3 = False
|
||||
if self.cfg.use_flash_attention_3 is True:
|
||||
use_fa3 = True
|
||||
elif self.cfg.use_flash_attention_3 == "auto":
|
||||
if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
|
||||
# FA3 is only available on Hopper GPUs and newer
|
||||
use_fa3 = True
|
||||
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
|
||||
from flash_attn_interface import (
|
||||
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
||||
)
|
||||
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||
flash_attn_func_v3
|
||||
)
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
flash_attn_varlen_func_v3
|
||||
)
|
||||
LOG.info("Switched to Flash Attention v3")
|
||||
|
||||
self.patch_attention()
|
||||
|
||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
@@ -699,6 +720,7 @@ class ModelLoader:
|
||||
|
||||
patch_mllama()
|
||||
|
||||
# TODO deprecate soon
|
||||
if self.model_config.model_type == "btlm":
|
||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||
replace_btlm_attn_with_flash_attn,
|
||||
@@ -706,6 +728,7 @@ class ModelLoader:
|
||||
|
||||
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
# TODO deprecate soon
|
||||
if (
|
||||
self.model_config.model_type == "stablelm_epoch"
|
||||
and self.cfg.sample_packing
|
||||
|
||||
@@ -233,6 +233,7 @@ class AxolotlInputConfig(
|
||||
flash_attn_fuse_qkv: bool | None = None
|
||||
flash_attn_fuse_mlp: bool | None = None
|
||||
flash_optimum: bool | None = None
|
||||
use_flash_attention_3: Literal["auto"] | bool | None = "auto"
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user