update to new api
This commit is contained in:
@@ -85,14 +85,14 @@ def _kernel_mg_forward_hopper(
|
||||
# Process this group
|
||||
|
||||
# Acquire hold on c_desc_ptr for TMA Store
|
||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
||||
tl.extra.cuda.tensormap.create2d(
|
||||
desc_ptr=c_desc_ptr,
|
||||
global_address=c_ptr + M_start * n_size,
|
||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
global_size=[m_size, n_size],
|
||||
element_ty=c_dtype,
|
||||
)
|
||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
||||
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||
|
||||
# tiles for this group
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
@@ -215,14 +215,14 @@ def _kernel_mg_forward_tma(
|
||||
n_size = N
|
||||
|
||||
# TMA Store prep
|
||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
||||
tl.extra.cuda.tensormap.create2d(
|
||||
desc_ptr=c_desc_ptr,
|
||||
global_address=c_ptr + M_start * N,
|
||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
global_size=[m_size, n_size],
|
||||
element_ty=c_dtype,
|
||||
)
|
||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
||||
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||
|
||||
# tiles for this group
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
@@ -447,14 +447,14 @@ def _kernel_mg_dx_tma(
|
||||
group_num_tiles = num_m_tiles * num_k_tiles
|
||||
|
||||
# TMA Store prep for [M, K] output
|
||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
||||
tl.extra.cuda.tensormap.create2d(
|
||||
desc_ptr=c_desc_ptr,
|
||||
global_address=grad_input_ptr + M_start * K,
|
||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
global_size=[m_size, K],
|
||||
element_ty=c_dtype,
|
||||
)
|
||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
||||
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||
|
||||
while tbidx >= processed_tiles and tbidx < (
|
||||
processed_tiles + group_num_tiles
|
||||
|
||||
Reference in New Issue
Block a user