hardening

This commit is contained in:
Dan Saunders
2025-09-15 19:41:10 -04:00
parent f6ed8ddc01
commit fef47a5b7c

View File

@@ -16,11 +16,13 @@ def available() -> bool:
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
if ver < (2, 8):
return False
# Check for aten grouped mm ops
return hasattr(torch.ops, "aten") and (
hasattr(torch.ops.aten, "_grouped_mm")
or hasattr(torch.ops.aten, "_scaled_grouped_mm")
)
# Require Hopper+ (SM90) per torch error message and check op presence
if not torch.cuda.is_available():
return False
major, minor = torch.cuda.get_device_capability()
if major < 9:
return False
return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
except Exception:
return False