From 9bdf4b1c23dac954704da0ad1a2ba0f903adced7 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 19 May 2025 10:11:14 -0700
Subject: [PATCH] improve handling and error if fa3 requested but not installed

---
 src/axolotl/utils/models.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a28234135..927fd77d8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -636,6 +636,14 @@ class ModelLoader:
         if torch.cuda.get_device_capability() >= (9, 0):
             # FA3 is only available on Hopper GPUs and newer
             use_fa3 = True
+            if not importlib.util.find_spec("flash_attn_interface"):
+                use_fa3 = False
+        if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+            # this can happen when use_flash_attention_3 is explicitly set to True
+            # and flash_attn_interface is not installed
+            raise ModuleNotFoundError(
+                "Please install the flash_attn_interface library to use Flash Attention 3.x"
+            )
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
             from flash_attn_interface import flash_attn_func as flash_attn_func_v3
             from flash_attn_interface import (
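
Note: the hunk above does not show the enclosing control flow in models.py, so the sketch below is only an illustration of the guard the patch introduces, not the file's actual structure. It assumes a config value named use_flash_attention_3 (referenced only in the patch's comment), where True means FA3 was explicitly requested, None means auto-detect, and False means disabled; the checks themselves mirror the added lines: auto-detection quietly falls back when flash_attn_interface is missing, while an explicit request fails loudly.

import importlib.util

import torch


def resolve_fa3(use_flash_attention_3):
    """Hypothetical standalone helper mirroring the patch's FA3 availability check."""
    fa3_installed = importlib.util.find_spec("flash_attn_interface") is not None

    # explicit request (True) is honored as-is; None triggers auto-detection
    use_fa3 = use_flash_attention_3 is True
    if use_flash_attention_3 is None and torch.cuda.is_available():
        if torch.cuda.get_device_capability() >= (9, 0):
            # FA3 is only available on Hopper GPUs and newer
            use_fa3 = True
            if not fa3_installed:
                # auto-detected but the package is absent: quietly fall back to FA2
                use_fa3 = False

    if use_fa3 and not fa3_installed:
        # this can happen when use_flash_attention_3 is explicitly set to True
        # and flash_attn_interface is not installed
        raise ModuleNotFoundError(
            "Please install the flash_attn_interface library to use Flash Attention 3.x"
        )
    return use_fa3

For example, resolve_fa3(True) on a machine without flash_attn_interface raises immediately, whereas resolve_fa3(None) on the same machine simply returns False and the caller keeps using FA2.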