diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 987bfeb6e..2c5049244 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -16,11 +16,13 @@ def available() -> bool:
     try:
         ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
         if ver < (2, 8):
             return False
-        # Check for aten grouped mm ops
-        return hasattr(torch.ops, "aten") and (
-            hasattr(torch.ops.aten, "_grouped_mm")
-            or hasattr(torch.ops.aten, "_scaled_grouped_mm")
-        )
+        # Require Hopper+ (SM90) per torch error message and check op presence
+        if not torch.cuda.is_available():
+            return False
+        major, minor = torch.cuda.get_device_capability()
+        if major < 9:
+            return False
+        return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
     except Exception:
         return False
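
For context, a minimal usage sketch of how a caller might gate on the tightened `available()` check, with a portable per-expert fallback. This is not part of the patch: `moe_expert_mm` is a hypothetical helper, and the `offs` keyword of `aten::_grouped_mm` is an assumption about the op's interface rather than something this diff establishes.

```python
# Hypothetical caller sketch -- assumes aten::_grouped_mm(a, b, offs=...)
# accepts a 2D activation tensor, a 3D per-expert weight tensor, and int32
# cumulative segment offsets (the op probed by available() on torch >= 2.8).
import torch

from axolotl.kernels.moe.torch_grouped import available


def moe_expert_mm(x: torch.Tensor, w: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """x: (tokens, d_in), sorted by expert; w: (n_experts, d_in, d_out);
    offs: cumulative end offset of each expert's token segment."""
    if available():
        # Fast path: one fused grouped GEMM over all expert segments
        # (only reachable on Hopper+ per the new capability check).
        return torch.ops.aten._grouped_mm(x, w, offs=offs)
    # Fallback: plain per-expert matmuls, works on any device/arch.
    outs, start = [], 0
    for e in range(w.shape[0]):
        end = int(offs[e])
        outs.append(x[start:end] @ w[e])
        start = end
    return torch.cat(outs, dim=0)
```

Because `available()` now also requires CUDA and SM90+, callers that previously saw the op-presence check pass on pre-Hopper GPUs will take the fallback branch instead of hitting the runtime error the comment in the diff refers to.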