improve handling and error if FA3 requested but not installed
@@ -636,6 +636,14 @@ class ModelLoader:
         if torch.cuda.get_device_capability() >= (9, 0):
             # FA3 is only available on Hopper GPUs and newer
             use_fa3 = True
+            if not importlib.util.find_spec("flash_attn_interface"):
+                use_fa3 = False
+        if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+            # this can happen when use_flash_attention_3 is explicitly set to True
+            # and flash_attn_interface is not installed
+            raise ModuleNotFoundError(
+                "Please install the flash_attn_interface library to use Flash Attention 3.x"
+            )
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
             from flash_attn_interface import flash_attn_func as flash_attn_func_v3
             from flash_attn_interface import (
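For reference, a minimal standalone sketch of the detection-and-error pattern this diff adds. The resolve_fa3 helper and its requested parameter are hypothetical illustrations (in the actual code the flag comes from the use_flash_attention_3 setting named in the diff's comments); only torch.cuda.get_device_capability and importlib.util.find_spec are taken from the diff itself:

    import importlib.util

    import torch

    def resolve_fa3(requested: bool | None = None) -> bool:
        """Decide whether Flash Attention 3 kernels can be used.

        requested=None means auto-detect; True/False mirrors an explicit
        use_flash_attention_3 setting. Hypothetical helper, not the actual
        ModelLoader method.
        """
        installed = importlib.util.find_spec("flash_attn_interface") is not None
        # FA3 is only available on Hopper (compute capability 9.0) and newer.
        supported = (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
        )
        if requested:
            if not installed:
                # Explicit request without the package installed: fail loudly.
                raise ModuleNotFoundError(
                    "Please install the flash_attn_interface library "
                    "to use Flash Attention 3.x"
                )
            return True
        if requested is False:
            return False
        # Auto mode: enable only when both the GPU and the package are present.
        return supported and installed

The design point mirrored here: auto-detection degrades silently when the package is absent, while an explicit request fails fast with an actionable error.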