From 4db7a21ff7000dc57b987f5ca7f71c491c171264 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 23 Sep 2025 18:03:41 -0400 Subject: [PATCH] fix --- .../kernels/moe/tt_mg_gemm/mg_grouped_gemm.py | 288 +++++++++++------- 1 file changed, 170 insertions(+), 118 deletions(-) diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py index c056cb667..9e540701e 100644 --- a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py +++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py @@ -48,7 +48,6 @@ def _kernel_mg_forward_hopper( K: tl.constexpr, # config NUM_SMS: tl.constexpr, - TMA_SIZE: tl.constexpr, USE_EPILOGUE_SUBTILING: tl.constexpr, # tiles BLOCK_SIZE_M: tl.constexpr, @@ -56,27 +55,21 @@ def _kernel_mg_forward_hopper( BLOCK_SIZE_K: tl.constexpr, ) -> None: """Flat index style forward kernel for Hopper using tensor descriptors.""" - tbidx = tl.program_id(0) # thread block index + tbidx = tl.program_id(0) - c_dtype = c_ptr.dtype.element_ty # output dtype + c_dtype = c_ptr.dtype.element_ty n_size = N // G - # Compute the total rows spanned by all groups so descriptors cover valid ranges. - total_rows = M_TOTAL.to(tl.int32) - - stride_k = tl.full([], K, dtype=tl.int32) - stride_1 = tl.full([], 1, dtype=tl.int32) - a_desc = tl.make_tensor_descriptor( a_ptr, - shape=[total_rows, stride_k], - strides=[stride_k, stride_1], + shape=[M_TOTAL, K], + strides=[K, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], ) b_desc = tl.make_tensor_descriptor( b_ptr, - shape=[tl.full([], N, dtype=tl.int32), stride_k], - strides=[stride_k, stride_1], + shape=[N, K], + strides=[K, 1], block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], ) @@ -88,55 +81,86 @@ def _kernel_mg_forward_hopper( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size > 0: - group_base = c_ptr + M_start * n_size - c_desc = tl.make_tensor_descriptor( - group_base, - shape=[m_size, tl.full([], n_size, dtype=tl.int32)], - strides=[tl.full([], n_size, dtype=tl.int32), stride_1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) + if m_size <= 0: + continue - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - group_num_tiles = num_m_tiles * num_n_tiles + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles - while ( - tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles - ): - group_index = tbidx - processed_tiles + while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: + group_index = tbidx - processed_tiles - tile_m_index = group_index % num_m_tiles - tile_n_index = group_index // num_m_tiles + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining - m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) - global_n_offset = (g * n_size + n_offset).to(tl.int32) + cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N + col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining - for k_offset in range(0, K, BLOCK_SIZE_K): - a = a_desc.load([m_offset, k_offset]) - b = b_desc.load([global_n_offset, k_offset]) - accumulator += tl.dot(a, b.T) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (g * n_size + n_offset).to(tl.int32) - if USE_EPILOGUE_SUBTILING: - acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - c_desc.store([local_m_offset, n_offset], acc0.to(c_dtype)) - c_desc.store( - [local_m_offset, n_offset + BLOCK_SIZE_N // 2], - acc1.to(c_dtype), - ) - else: - c_desc.store([local_m_offset, n_offset], accumulator.to(c_dtype)) + for k_offset in range(0, K, BLOCK_SIZE_K): + k_remaining = K - k_offset + k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining - tbidx += NUM_SMS + a = a_desc.load([m_offset, k_offset]) + a_mask = row_mask[:, None] & k_mask[None, :] + a = tl.where(a_mask, a, tl.zeros_like(a)) - processed_tiles += group_num_tiles + b = b_desc.load([global_n_offset, k_offset]) + b_mask = col_mask[:, None] & k_mask[None, :] + b = tl.where(b_mask, b, tl.zeros_like(b)) + + accumulator += tl.dot(a, b.T) + + local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + + local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M) + row_store_mask = local_row_offsets < m_size + global_row = (M_start + local_row_offsets).to(tl.int32) + + local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + col_store_mask = local_col_offsets < n_size + + store_mask = row_store_mask[:, None] & col_store_mask[None, :] + + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + + col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2] + col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2] + ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :] + tl.store( + ptr0, + acc0.to(c_dtype), + mask=row_store_mask[:, None] & col_mask0[None, :], + ) + + col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :] + col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :] + ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :] + tl.store( + ptr1, + acc1.to(c_dtype), + mask=row_store_mask[:, None] & col_mask1[None, :], + ) + else: + ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :] + tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) + + tbidx += NUM_SMS + + processed_tiles += group_num_tiles """ @@ -174,22 +198,17 @@ def _kernel_mg_dx_tma( tbidx = tl.program_id(0) c_dtype = grad_input_ptr.dtype.element_ty - stride_1 = tl.full([], 1, dtype=tl.int32) - stride_n = tl.full([], N, dtype=tl.int32) - stride_k = tl.full([], K, dtype=tl.int32) - - total_rows = M_TOTAL.to(tl.int32) grad_output_desc = tl.make_tensor_descriptor( grad_output_ptr, - shape=[total_rows, stride_n], - strides=[stride_n, stride_1], + shape=[M_TOTAL, N], + strides=[N, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], ) w_desc = tl.make_tensor_descriptor( w_ptr, - shape=[stride_n, stride_k], - strides=[stride_k, stride_1], + shape=[N, K], + strides=[K, 1], block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], ) @@ -201,43 +220,60 @@ def _kernel_mg_dx_tma( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size > 0: - grad_input_desc = tl.make_tensor_descriptor( - grad_input_ptr + M_start * K, - shape=[m_size, stride_k], - strides=[stride_k, stride_1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) + if m_size <= 0: + continue - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) - group_num_tiles = num_m_tiles * num_k_tiles + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles - while ( - tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles - ): - group_index = tbidx - processed_tiles - tile_m_index = group_index % num_m_tiles - tile_k_index = group_index // num_m_tiles + while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: + group_index = tbidx - processed_tiles - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles - m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) - k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining - for n_offset in range(0, N, BLOCK_SIZE_N): - grad_y = grad_output_desc.load([m_offset, n_offset]) - w_tile = w_desc.load([n_offset, k_offset]) - accumulator += tl.dot(grad_y, w_tile) + k_offset = tile_k_index * BLOCK_SIZE_K + k_remaining_total = K - k_offset + k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total - local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) - grad_input_desc.store( - [local_m_offset, k_offset], accumulator.to(c_dtype) - ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - tbidx += NUM_SMS + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) - processed_tiles += group_num_tiles + for n_offset in range(0, N, BLOCK_SIZE_N): + n_remaining = N - n_offset + n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining + + grad_y = grad_output_desc.load([m_offset, n_offset]) + grad_y_mask = row_mask[:, None] & n_mask[None, :] + grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y)) + + w_tile = w_desc.load([n_offset, k_offset]) + w_mask = n_mask[:, None] & k_mask[None, :] + w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile)) + + accumulator += tl.dot(grad_y, w_tile) + + local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + row_store_mask = local_row_offsets < m_size + global_row = (M_start + local_row_offsets).to(tl.int32) + + col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K) + col_store_mask = col_offsets < K + + store_mask = row_store_mask[:, None] & col_store_mask[None, :] + + ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :] + tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) + + tbidx += NUM_SMS + + processed_tiles += group_num_tiles @triton.autotune( @@ -268,30 +304,19 @@ def _kernel_mg_dw_tma( tbidx = tl.program_id(0) c_dtype = grad_weight_ptr.dtype.element_ty - stride_1 = tl.full([], 1, dtype=tl.int32) - stride_n = tl.full([], N, dtype=tl.int32) - stride_k = tl.full([], K, dtype=tl.int32) - - total_rows = M_TOTAL.to(tl.int32) x_desc = tl.make_tensor_descriptor( x_ptr, - shape=[total_rows, stride_k], - strides=[stride_k, stride_1], + shape=[M_TOTAL, K], + strides=[K, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], ) grad_output_desc = tl.make_tensor_descriptor( grad_output_ptr, - shape=[total_rows, stride_n], - strides=[stride_n, stride_1], + shape=[M_TOTAL, N], + strides=[N, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], ) - grad_weight_desc = tl.make_tensor_descriptor( - grad_weight_ptr, - shape=[stride_n, stride_k], - strides=[stride_k, stride_1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) @@ -301,8 +326,13 @@ def _kernel_mg_dw_tma( tile_n_idx = tile_idx % num_n_tiles tile_k_idx = tile_idx // num_n_tiles - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - k_offset = (tile_k_idx * BLOCK_SIZE_K).to(tl.int32) + n_offset = tile_n_idx * BLOCK_SIZE_N + n_remaining = N - n_offset + n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining + + k_offset = tile_k_idx * BLOCK_SIZE_K + k_remaining = K - k_offset + k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) @@ -312,17 +342,40 @@ def _kernel_mg_dw_tma( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size > 0: - for m_offset in range(0, m_size, BLOCK_SIZE_M): - m_global = (M_start + m_offset).to(tl.int32) - grad_block = grad_output_desc.load([m_global, n_offset]) - x_block = x_desc.load([m_global, k_offset]) - accumulator += tl.dot( - grad_block.to(tl.float32).T, - x_block.to(tl.float32), - ) + if m_size <= 0: + continue - grad_weight_desc.store([n_offset, k_offset], accumulator.to(c_dtype)) + for m_offset_local in range(0, m_size, BLOCK_SIZE_M): + rows_remaining = m_size - m_offset_local + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining + + m_offset = (M_start + m_offset_local).to(tl.int32) + + x_block = x_desc.load([m_offset, k_offset]) + x_mask = row_mask[:, None] & k_mask[None, :] + x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block)) + + grad_block = grad_output_desc.load([m_offset, n_offset]) + grad_mask = row_mask[:, None] & n_mask[None, :] + grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block)) + + contribution = tl.dot( + grad_block.to(tl.float32).T, + x_block.to(tl.float32), + ) + accumulator += contribution + + row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N) + row_store_mask = row_offsets < N + + col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K) + col_store_mask = col_offsets < K + + store_mask = row_store_mask[:, None] & col_store_mask[None, :] + + ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :] + tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) # ======== End Triton kernels ======== @@ -380,7 +433,6 @@ def grouped_gemm_forward( N, K, NUM_SMS, - TMA_SIZE=tma_size, USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING, ) return y