This commit is contained in:
Dan Saunders
2025-09-23 20:21:22 +00:00
parent a81612305c
commit 6369dcd7b8

View File

@@ -15,7 +15,6 @@ from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -71,10 +70,6 @@ class TmaDescriptorHelper:
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor