diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a28234135..927fd77d8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -636,6 +636,14 @@ class ModelLoader:
         if torch.cuda.get_device_capability() >= (9, 0):
             # FA3 is only available on Hopper GPUs and newer
             use_fa3 = True
+        if not importlib.util.find_spec("flash_attn_interface"):
+            use_fa3 = False
+        if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+            # this can happen when use_flash_attention_3 is explicitly set to True
+            # and flash_attn_interface is not installed
+            raise ModuleNotFoundError(
+                "Please install the flash_attn_interface library to use Flash Attention 3.x"
+            )
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
             from flash_attn_interface import flash_attn_func as flash_attn_func_v3
             from flash_attn_interface import (