From fef47a5b7c15c6bcbc3a101e34ab309e0518623c Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 15 Sep 2025 19:41:10 -0400
Subject: [PATCH] hardening

---
 src/axolotl/kernels/moe/torch_grouped.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 987bfeb6e..2c5049244 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -16,11 +16,13 @@ def available() -> bool:
         ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
         if ver < (2, 8):
             return False
-        # Check for aten grouped mm ops
-        return hasattr(torch.ops, "aten") and (
-            hasattr(torch.ops.aten, "_grouped_mm")
-            or hasattr(torch.ops.aten, "_scaled_grouped_mm")
-        )
+        # Require Hopper+ (SM90) per torch error message and check op presence
+        if not torch.cuda.is_available():
+            return False
+        major, minor = torch.cuda.get_device_capability()
+        if major < 9:
+            return False
+        return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
     except Exception:
         return False
 