update to new api

This commit is contained in:
Dan Saunders
2025-09-23 16:44:26 -04:00
parent 6369dcd7b8
commit 1037ca3a97

View File

@@ -85,14 +85,14 @@ def _kernel_mg_forward_hopper(
# Process this group
# Acquire hold on c_desc_ptr for TMA Store
tl.extra.cuda.experimental_device_tensormap_create2d(
tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start * n_size,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_dtype,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
# tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
@@ -215,14 +215,14 @@ def _kernel_mg_forward_tma(
n_size = N
# TMA Store prep
tl.extra.cuda.experimental_device_tensormap_create2d(
tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_dtype,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
# tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
@@ -447,14 +447,14 @@ def _kernel_mg_dx_tma(
group_num_tiles = num_m_tiles * num_k_tiles
# TMA Store prep for [M, K] output
tl.extra.cuda.experimental_device_tensormap_create2d(
tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr,
global_address=grad_input_ptr + M_start * K,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
global_size=[m_size, K],
element_ty=c_dtype,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
while tbidx >= processed_tiles and tbidx < (
processed_tiles + group_num_tiles