hardening
This commit is contained in:
@@ -16,11 +16,13 @@ def available() -> bool:
|
||||
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
||||
if ver < (2, 8):
|
||||
return False
|
||||
# Check for aten grouped mm ops
|
||||
return hasattr(torch.ops, "aten") and (
|
||||
hasattr(torch.ops.aten, "_grouped_mm")
|
||||
or hasattr(torch.ops.aten, "_scaled_grouped_mm")
|
||||
)
|
||||
# Require Hopper+ (SM90) per torch error message and check op presence
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 9:
|
||||
return False
|
||||
return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user