fix
This commit is contained in:
@@ -15,7 +15,6 @@ from typing import Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
from triton.runtime import driver # @manual
|
from triton.runtime import driver # @manual
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -71,10 +70,6 @@ class TmaDescriptorHelper:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"TMA not supported on this device (requires Hopper or newer)"
|
"TMA not supported on this device (requires Hopper or newer)"
|
||||||
)
|
)
|
||||||
if "nv_tma_desc_type" not in dir(tl):
|
|
||||||
raise RuntimeError(
|
|
||||||
"TMA grid constant descriptors not supported in your Triton version"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tma_size = tma_size
|
self.tma_size = tma_size
|
||||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
|
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
|
||||||
|
|||||||
Reference in New Issue
Block a user