hardening
This commit is contained in:
@@ -16,11 +16,13 @@ def available() -> bool:
|
|||||||
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
||||||
if ver < (2, 8):
|
if ver < (2, 8):
|
||||||
return False
|
return False
|
||||||
# Check for aten grouped mm ops
|
# Require Hopper+ (SM90) per torch error message and check op presence
|
||||||
return hasattr(torch.ops, "aten") and (
|
if not torch.cuda.is_available():
|
||||||
hasattr(torch.ops.aten, "_grouped_mm")
|
return False
|
||||||
or hasattr(torch.ops.aten, "_scaled_grouped_mm")
|
major, minor = torch.cuda.get_device_capability()
|
||||||
)
|
if major < 9:
|
||||||
|
return False
|
||||||
|
return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user