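Update TMA tensormap calls to the new `tl.extra.cuda.tensormap` API (`create2d`, `fenceproxy_acquire`), replacing the `experimental_*` names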
This commit is contained in:
@@ -85,14 +85,14 @@ def _kernel_mg_forward_hopper(
|
|||||||
# Process this group
|
# Process this group
|
||||||
|
|
||||||
# Acquire hold on c_desc_ptr for TMA Store
|
# Acquire hold on c_desc_ptr for TMA Store
|
||||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
tl.extra.cuda.tensormap.create2d(
|
||||||
desc_ptr=c_desc_ptr,
|
desc_ptr=c_desc_ptr,
|
||||||
global_address=c_ptr + M_start * n_size,
|
global_address=c_ptr + M_start * n_size,
|
||||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||||
global_size=[m_size, n_size],
|
global_size=[m_size, n_size],
|
||||||
element_ty=c_dtype,
|
element_ty=c_dtype,
|
||||||
)
|
)
|
||||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||||
|
|
||||||
# tiles for this group
|
# tiles for this group
|
||||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||||
@@ -215,14 +215,14 @@ def _kernel_mg_forward_tma(
|
|||||||
n_size = N
|
n_size = N
|
||||||
|
|
||||||
# TMA Store prep
|
# TMA Store prep
|
||||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
tl.extra.cuda.tensormap.create2d(
|
||||||
desc_ptr=c_desc_ptr,
|
desc_ptr=c_desc_ptr,
|
||||||
global_address=c_ptr + M_start * N,
|
global_address=c_ptr + M_start * N,
|
||||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||||
global_size=[m_size, n_size],
|
global_size=[m_size, n_size],
|
||||||
element_ty=c_dtype,
|
element_ty=c_dtype,
|
||||||
)
|
)
|
||||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||||
|
|
||||||
# tiles for this group
|
# tiles for this group
|
||||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||||
@@ -447,14 +447,14 @@ def _kernel_mg_dx_tma(
|
|||||||
group_num_tiles = num_m_tiles * num_k_tiles
|
group_num_tiles = num_m_tiles * num_k_tiles
|
||||||
|
|
||||||
# TMA Store prep for [M, K] output
|
# TMA Store prep for [M, K] output
|
||||||
tl.extra.cuda.experimental_device_tensormap_create2d(
|
tl.extra.cuda.tensormap.create2d(
|
||||||
desc_ptr=c_desc_ptr,
|
desc_ptr=c_desc_ptr,
|
||||||
global_address=grad_input_ptr + M_start * K,
|
global_address=grad_input_ptr + M_start * K,
|
||||||
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||||
global_size=[m_size, K],
|
global_size=[m_size, K],
|
||||||
element_ty=c_dtype,
|
element_ty=c_dtype,
|
||||||
)
|
)
|
||||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr)
|
||||||
|
|
||||||
while tbidx >= processed_tiles and tbidx < (
|
while tbidx >= processed_tiles and tbidx < (
|
||||||
processed_tiles + group_num_tiles
|
processed_tiles + group_num_tiles
|
||||||
|
|||||||
Reference in New Issue
Block a user