Update to the new `tl.extra.cuda.tensormap` API

This commit is contained in:
Dan Saunders
2025-09-23 16:44:26 -04:00
parent 6369dcd7b8
commit 1037ca3a97

View File

@@ -85,14 +85,14 @@ def _kernel_mg_forward_hopper(
# Process this group # Process this group
# Acquire hold on c_desc_ptr for TMA Store # Acquire hold on c_desc_ptr for TMA Store
tl.extra.cuda.experimental_device_tensormap_create2d( tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr, desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start * n_size, global_address=c_ptr + M_start * n_size,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size], global_size=[m_size, n_size],
element_ty=c_dtype, element_ty=c_dtype,
) )
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
# tiles for this group # tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
@@ -215,14 +215,14 @@ def _kernel_mg_forward_tma(
n_size = N n_size = N
# TMA Store prep # TMA Store prep
tl.extra.cuda.experimental_device_tensormap_create2d( tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr, desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start * N, global_address=c_ptr + M_start * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size], global_size=[m_size, n_size],
element_ty=c_dtype, element_ty=c_dtype,
) )
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
# tiles for this group # tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
@@ -447,14 +447,14 @@ def _kernel_mg_dx_tma(
group_num_tiles = num_m_tiles * num_k_tiles group_num_tiles = num_m_tiles * num_k_tiles
# TMA Store prep for [M, K] output # TMA Store prep for [M, K] output
tl.extra.cuda.experimental_device_tensormap_create2d( tl.extra.cuda.tensormap.create2d(
desc_ptr=c_desc_ptr, desc_ptr=c_desc_ptr,
global_address=grad_input_ptr + M_start * K, global_address=grad_input_ptr + M_start * K,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
global_size=[m_size, K], global_size=[m_size, K],
element_ty=c_dtype, element_ty=c_dtype,
) )
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
while tbidx >= processed_tiles and tbidx < ( while tbidx >= processed_tiles and tbidx < (
processed_tiles + group_num_tiles processed_tiles + group_num_tiles