improve handling and error if fa3 requested but not installed

Wing Lian
2025-05-19 10:11:14 -07:00
parent d6f64a3684
commit 9bdf4b1c23


@@ -636,6 +636,14 @@ class ModelLoader:
         if torch.cuda.get_device_capability() >= (9, 0):
             # FA3 is only available on Hopper GPUs and newer
             use_fa3 = True
         if not importlib.util.find_spec("flash_attn_interface"):
             use_fa3 = False
+        if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+            # this can happen when use_flash_attention_3 is explicitly set to True
+            # and flash_attn_interface is not installed
+            raise ModuleNotFoundError(
+                "Please install the flash_attn_interface library to use Flash Attention 3.x"
+            )
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
+            from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+            from flash_attn_interface import (
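
For context on what the added branch buys: the detection follows a common probe-then-gate pattern: check the GPU's compute capability, check for the optional package with importlib.util.find_spec, and raise only when the user explicitly opted in. Below is a minimal standalone sketch of that pattern; the resolve_fa3 helper and its use_flash_attention_3 argument are illustrative stand-ins, not code from this commit.

import importlib.util

import torch


def resolve_fa3(use_flash_attention_3=None):
    # Illustrative helper (hypothetical name): decide whether Flash
    # Attention 3 can be used. use_flash_attention_3 mirrors an explicit
    # user setting: True/False to force, None to auto-detect.
    use_fa3 = use_flash_attention_3
    if use_fa3 is None:
        # Auto-detect: FA3 needs a Hopper-class (compute capability 9.0+) GPU ...
        use_fa3 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
        # ... and the flash_attn_interface package must be importable.
        if not importlib.util.find_spec("flash_attn_interface"):
            use_fa3 = False
    if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
        # Only reachable on an explicit opt-in; auto-detection already
        # cleared the flag above when the package is missing.
        raise ModuleNotFoundError(
            "Please install the flash_attn_interface library to use Flash Attention 3.x"
        )
    return bool(use_fa3)

Raising only on the explicit opt-in preserves the silent fallback during auto-detection, which is exactly the distinction the added if use_fa3 and not importlib.util.find_spec(...) branch in the diff draws.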