Use torch.cuda.get_device_capability(), since the CI-set compute capability in cfg is unreliable
This commit is contained in:
@@ -633,7 +633,7 @@ class ModelLoader:
|
||||
if self.cfg.use_flash_attention_3 is True:
|
||||
use_fa3 = True
|
||||
elif self.cfg.use_flash_attention_3 == "auto":
|
||||
if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
|
||||
if torch.cuda.get_device_capability() >= (9, 0):
|
||||
# FA3 is only available on Hopper GPUs and newer
|
||||
use_fa3 = True
|
||||
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
|
||||
|
||||
Reference in New Issue
Block a user