From 1037ca3a978a681eff070a0648264e6abf79814d Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 23 Sep 2025 16:44:26 -0400 Subject: [PATCH] update to new api --- .../kernels/moe/tt_mg_gemm/mg_grouped_gemm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py index 58deaeddb..bd19136e9 100644 --- a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py +++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py @@ -85,14 +85,14 @@ def _kernel_mg_forward_hopper( # Process this group # Acquire hold on c_desc_ptr for TMA Store - tl.extra.cuda.experimental_device_tensormap_create2d( + tl.extra.cuda.tensormap.create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start * n_size, load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_dtype, ) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) # tiles for this group num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) @@ -215,14 +215,14 @@ def _kernel_mg_forward_tma( n_size = N # TMA Store prep - tl.extra.cuda.experimental_device_tensormap_create2d( + tl.extra.cuda.tensormap.create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start * N, load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_dtype, ) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) # tiles for this group num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) @@ -447,14 +447,14 @@ def _kernel_mg_dx_tma( group_num_tiles = num_m_tiles * num_k_tiles # TMA Store prep for [M, K] output - tl.extra.cuda.experimental_device_tensormap_create2d( + tl.extra.cuda.tensormap.create2d( desc_ptr=c_desc_ptr, global_address=grad_input_ptr + M_start * K, load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[m_size, K], element_ty=c_dtype, ) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) while tbidx >= processed_tiles and tbidx < ( processed_tiles + group_num_tiles