use get_device_capability since CI setting in cfg is unreliable

Wing Lian
2025-05-18 10:04:25 -07:00
parent 323a9cb153
commit bb6464c4c6


@@ -633,7 +633,7 @@ class ModelLoader:
         if self.cfg.use_flash_attention_3 is True:
             use_fa3 = True
         elif self.cfg.use_flash_attention_3 == "auto":
-            if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
+            if torch.cuda.get_device_capability() >= (9, 0):
                 # FA3 is only available on Hopper GPUs and newer
                 use_fa3 = True
         if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
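For context, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple for the active CUDA device, e.g. `(9, 0)` for an H100, so comparing against `(9, 0)` selects Hopper and newer without relying on a capability string set in the config by CI. A minimal sketch of the "auto" detection path under that assumption (the `should_use_fa3` helper name is hypothetical, not part of this patch):

```python
import importlib.util

import torch


def should_use_fa3() -> bool:
    """Sketch: enable FlashAttention 3 only on Hopper (SM 9.0) or newer."""
    if not torch.cuda.is_available():
        return False
    # get_device_capability() returns (major, minor), e.g. (9, 0) on H100
    if torch.cuda.get_device_capability() >= (9, 0):
        # Only enable FA3 when the kernel package is actually importable
        return importlib.util.find_spec("flash_attn_interface") is not None
    return False
```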