diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5f84079da..3dd5fc21b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -633,7 +633,7 @@ class ModelLoader: if self.cfg.use_flash_attention_3 is True: use_fa3 = True elif self.cfg.use_flash_attention_3 == "auto": - if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90: + if torch.cuda.get_device_capability() >= (9, 0): # FA3 is only available on Hopper GPUs and newer use_fa3 = True if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None: