diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py index bd19136e9..c056cb667 100644 --- a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py +++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py @@ -18,7 +18,6 @@ import triton.language as tl from .tma_autotuning import ( _NV_CONFIGS, CudaUtils, - TmaDescriptorHelper, early_config_prune, ) @@ -37,11 +36,11 @@ logging.basicConfig( ) @triton.jit def _kernel_mg_forward_hopper( - a_desc_ptr, - b_desc_ptr, + a_ptr, + b_ptr, c_ptr, - workspace, m_sizes, + M_TOTAL, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, @@ -56,326 +55,87 @@ def _kernel_mg_forward_hopper( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: - """ - Flat index style forward kernel for Hopper. - For simplicity, we always use TMA Load and TMA Store - """ + """Flat index style forward kernel for Hopper using tensor descriptors.""" tbidx = tl.program_id(0) # thread block index c_dtype = c_ptr.dtype.element_ty # output dtype - - c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store - - M_end = 0 - M_start = 0 - processed_tiles = 0 - # Size of individual weight matrix n_size = N // G - n_start = 0 + + # Compute the total rows spanned by all groups so descriptors cover valid ranges. + total_rows = M_TOTAL.to(tl.int32) + + stride_k = tl.full([], K, dtype=tl.int32) + stride_1 = tl.full([], 1, dtype=tl.int32) + + a_desc = tl.make_tensor_descriptor( + a_ptr, + shape=[total_rows, stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl.make_tensor_descriptor( + b_ptr, + shape=[tl.full([], N, dtype=tl.int32), stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + + M_end = tl.full([], 0, dtype=tl.int32) + processed_tiles = 0 for g in range(G): - # Move down along groups - # reset to new M offset M_start = M_end m_size = tl.load(m_sizes + g) M_end = M_start + m_size - n_start = n_size * g if m_size > 0: - # Process this group - - # Acquire hold on c_desc_ptr for TMA Store - tl.extra.cuda.tensormap.create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start * n_size, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, n_size], - element_ty=c_dtype, + group_base = c_ptr + M_start * n_size + c_desc = tl.make_tensor_descriptor( + group_base, + shape=[m_size, tl.full([], n_size, dtype=tl.int32)], + strides=[tl.full([], n_size, dtype=tl.int32), stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], ) - tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) - # tiles for this group num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) group_num_tiles = num_m_tiles * num_n_tiles - while tbidx >= processed_tiles and tbidx < ( - processed_tiles + group_num_tiles + while ( + tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles ): group_index = tbidx - processed_tiles - # columnwise tile_m_index = group_index % num_m_tiles tile_n_index = group_index // num_m_tiles accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) - global_n_offset = (n_start + n_offset).to(tl.int32) + global_n_offset = (g * n_size + n_offset).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): - # input block [M,K] - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - c_dtype, - ) - # weight block [N, K] - b = tl._experimental_descriptor_load( - b_desc_ptr, - [global_n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - c_dtype, - ) - + a = a_desc.load([m_offset, k_offset]) + b = b_desc.load([global_n_offset, k_offset]) accumulator += tl.dot(a, b.T) - # Store using TMA - - m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) if USE_EPILOGUE_SUBTILING: acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) acc = tl.permute(acc, (0, 2, 1)) acc0, acc1 = tl.split(acc) - c0 = acc0.to(c_dtype) - tl._experimental_descriptor_store( - c_desc_ptr, c0, [m_offset, n_offset] - ) - c1 = acc1.to(c_dtype) - tl._experimental_descriptor_store( - c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2] + c_desc.store([local_m_offset, n_offset], acc0.to(c_dtype)) + c_desc.store( + [local_m_offset, n_offset + BLOCK_SIZE_N // 2], + acc1.to(c_dtype), ) else: - tl._experimental_descriptor_store( - c_desc_ptr, - accumulator.to(c_dtype), - [m_offset, n_offset], - ) - # move to next tile in group + c_desc.store([local_m_offset, n_offset], accumulator.to(c_dtype)) + tbidx += NUM_SMS - # Update the total tiles count for the next group - processed_tiles += group_num_tiles - -@triton.autotune( - configs=_NV_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune}, -) -@triton.jit -def _kernel_mg_forward_tma( - a_desc_ptr, - b_desc_ptr, - c_ptr, - workspace, - m_sizes, - a_scale_ptr, - b_scale_ptr, - # problem sizes - G: tl.constexpr, - M_BUCKET: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - # config - NUM_SMS: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - TMA_SIZE: tl.constexpr, - USE_FP8: tl.constexpr, - # tiles - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -) -> None: - """ - Flat index style forward kernel. - For simplicity, we always use TMA Load and TMA Store - """ - tbidx = tl.program_id(0) # thread block index - - c_dtype = c_ptr.dtype.element_ty - - c_desc_ptr = workspace + (tbidx * TMA_SIZE) - - M_end = 0 - processed_tiles = 0 - - for g in range(G): - # Move down along groups - # reset to new M offset - M_start = M_end - m_size = tl.load(m_sizes + g) - M_end = M_start + m_size - - if m_size > 0: - # Process this group - n_size = N - - # TMA Store prep - tl.extra.cuda.tensormap.create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start * N, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, n_size], - element_ty=c_dtype, - ) - tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) - - # tiles for this group - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - group_num_tiles = num_m_tiles * num_n_tiles - - while tbidx >= processed_tiles and tbidx < ( - processed_tiles + group_num_tiles - ): - group_index = tbidx - processed_tiles - - tile_m_index = group_index % num_m_tiles - tile_n_index = group_index // num_m_tiles - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) - n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) - - for k_offset in range(0, K, BLOCK_SIZE_K): - # input block [M,K] - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - c_dtype, - ) - # weight block [N, K] - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - c_dtype, - ) - - accumulator += tl.dot(a, b.T) - - # Store using TMA - - m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) - # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) - - tl._experimental_descriptor_store( - c_desc_ptr, - accumulator.to(c_dtype), - [m_offset, n_offset], - ) - - # Move to the next tile - tbidx += NUM_SMS - # Update the total tiles count for the next group - processed_tiles += group_num_tiles - - -@triton.autotune( - configs=_NV_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune}, -) -@triton.jit -def _kernel_mg_forward_no_tma( - a_ptr, - b_ptr, - c_ptr, - workspace, - m_sizes, - # problem sizes - G: tl.constexpr, - M_BUCKET: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - # config - NUM_SMS: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - TMA_SIZE: tl.constexpr, - # tiles - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -) -> None: - """ - Flat index style forward kernel. - For bc and Ampere, we never use TMA Load and TMA Store - """ - tbidx = tl.program_id(0) # thread block index - - c_dtype = c_ptr.dtype.element_ty - - M_end = 0 - processed_tiles = 0 - - for g in range(G): - # Move down along groups - # reset to new M offset - M_start = M_end - m_size = tl.load(m_sizes + g) - M_end = M_start + m_size - - if m_size > 0: - # Process this group - n_size = N - - # tiles for this group - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - group_num_tiles = num_m_tiles * num_n_tiles - - while tbidx >= processed_tiles and tbidx < ( - processed_tiles + group_num_tiles - ): - group_index = tbidx - processed_tiles - - tile_m_index = group_index % num_m_tiles - tile_n_index = group_index // num_m_tiles - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :] - b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :] - - for _ in range(0, K, BLOCK_SIZE_K): - # Load with bounds checking - a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) - b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) - - # Main matmul - accumulator += tl.dot(a, b.T) - - # Update pointers for next block - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - - # Store without TMA - offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - c = accumulator.to(c_dtype) - - tl.store( - c_ptr - + (M_start + offs_am[:, None]) * N # Row stride is N - + offs_bn[None, :], # Column offset - c, - mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, - ) - # Move to the next tile - tbidx += NUM_SMS - # Update the total tiles count for the next group processed_tiles += group_num_tiles @@ -393,11 +153,11 @@ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w` ) @triton.jit def _kernel_mg_dx_tma( - grad_output_desc_ptr, # [MG, N] - w_desc_ptr, # [N, K] - grad_input_ptr, # output grad_x [MG, K] - workspace, # for TMA store - m_sizes, # group sizes [G] + grad_output_ptr, + w_ptr, + grad_input_ptr, + m_sizes, + M_TOTAL, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, @@ -405,115 +165,81 @@ def _kernel_mg_dx_tma( K: tl.constexpr, # config NUM_SMS: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - TMA_SIZE: tl.constexpr, # tiles BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: - """ - TMA-optimized kernel for computing gradients with respect to input (dx). - For the forward pass Y = X @ W.T, the backward for input is: - grad_X = grad_Y @ W - - This maps to [MG, N] @ [N, K] -> [MG, K] - - Key differences from forward: - 1. W is used directly and not transposed - 2. The reduction dimension is now N (not K) - 3. Output is [M, K] instead of [M, N] - """ - tbidx = tl.program_id(0) # thread block index + """Compute grad_input = grad_output @ w using tensor descriptors.""" + tbidx = tl.program_id(0) c_dtype = grad_input_ptr.dtype.element_ty - c_desc_ptr = workspace + (tbidx * TMA_SIZE) + stride_1 = tl.full([], 1, dtype=tl.int32) + stride_n = tl.full([], N, dtype=tl.int32) + stride_k = tl.full([], K, dtype=tl.int32) - M_end = 0 + total_rows = M_TOTAL.to(tl.int32) + + grad_output_desc = tl.make_tensor_descriptor( + grad_output_ptr, + shape=[total_rows, stride_n], + strides=[stride_n, stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + w_desc = tl.make_tensor_descriptor( + w_ptr, + shape=[stride_n, stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + + M_end = tl.full([], 0, dtype=tl.int32) processed_tiles = 0 for g in range(G): - # Move down along groups - same as forward M_start = M_end m_size = tl.load(m_sizes + g) M_end = M_start + m_size if m_size > 0: - # Process this group - # tiles for this group - now producing [M, K] output + grad_input_desc = tl.make_tensor_descriptor( + grad_input_ptr + M_start * K, + shape=[m_size, stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) group_num_tiles = num_m_tiles * num_k_tiles - # TMA Store prep for [M, K] output - tl.extra.cuda.tensormap.create2d( - desc_ptr=c_desc_ptr, - global_address=grad_input_ptr + M_start * K, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], - global_size=[m_size, K], - element_ty=c_dtype, - ) - tl.extra.cuda.tensormap.fenceproxy_acquire(c_desc_ptr) - - while tbidx >= processed_tiles and tbidx < ( - processed_tiles + group_num_tiles + while ( + tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles ): group_index = tbidx - processed_tiles - - # Different tiling scheme for [M, K] output tile_m_index = group_index % num_m_tiles tile_k_index = group_index // num_m_tiles - # for grad_input block [M, K] accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - # Position in full matrix - m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) - # reduce along N dimension (instead of K in forward) for n_offset in range(0, N, BLOCK_SIZE_N): - # grad_output block [M, N] - grad_output = tl._experimental_descriptor_load( - grad_output_desc_ptr, - [m_offset, n_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - c_dtype, - ) + grad_y = grad_output_desc.load([m_offset, n_offset]) + w_tile = w_desc.load([n_offset, k_offset]) + accumulator += tl.dot(grad_y, w_tile) - # weight block [N, K] - no transpose needed - w = tl._experimental_descriptor_load( - w_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - c_dtype, - ) - - # grad_x = grad_output @ w - # reducing along N dimension - accumulator += tl.dot(grad_output, w) - - # Store using TMA - m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) - # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) - - tl._experimental_descriptor_store( - c_desc_ptr, - accumulator.to(c_dtype), - [m_offset, k_offset], + local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + grad_input_desc.store( + [local_m_offset, k_offset], accumulator.to(c_dtype) ) - # Move to the next tile tbidx += NUM_SMS - # Update the total tiles count for the next group processed_tiles += group_num_tiles -# ---- dw flat linear indexed ---- - - @triton.autotune( configs=_NV_CONFIGS, key=["G", "M_BUCKET", "N", "K"], @@ -521,11 +247,11 @@ def _kernel_mg_dx_tma( ) @triton.jit def _kernel_mg_dw_tma( - x_desc_ptr, # input descriptor [M_total, K] - grad_output_desc_ptr, # grad_output descriptor [M_total, N] - grad_weight_ptr, # output grad_w [N, K] - workspace, # workspace for TMA store - m_sizes, # group sizes [G] + x_ptr, + grad_output_ptr, + grad_weight_ptr, + m_sizes, + M_TOTAL, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, @@ -533,172 +259,73 @@ def _kernel_mg_dw_tma( K: tl.constexpr, # config NUM_SMS: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - TMA_SIZE: tl.constexpr, # tiles BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension + BLOCK_SIZE_M: tl.constexpr, ) -> None: - """ - Improved TMA-optimized kernel for computing gradients with respect to weights (dw). - Uses flat index structure similar to forward. - - For the forward pass Y = X @ W.T, - the backward for weights is: - grad_W = grad_Y.T @ X - - Where: - - grad_Y is [MG, N] - - X is [MG, K] - - grad_W is [N, K] - - we return [N,K] - """ - # Get thread block index l + """Compute grad_weight = grad_output.T @ x using tensor descriptors.""" tbidx = tl.program_id(0) - # Get output data type c_dtype = grad_weight_ptr.dtype.element_ty + stride_1 = tl.full([], 1, dtype=tl.int32) + stride_n = tl.full([], N, dtype=tl.int32) + stride_k = tl.full([], K, dtype=tl.int32) + + total_rows = M_TOTAL.to(tl.int32) + + x_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[total_rows, stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + grad_output_desc = tl.make_tensor_descriptor( + grad_output_ptr, + shape=[total_rows, stride_n], + strides=[stride_n, stride_1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + grad_weight_desc = tl.make_tensor_descriptor( + grad_weight_ptr, + shape=[stride_n, stride_k], + strides=[stride_k, stride_1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) - # Calculate number of output tiles num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) - total_output_tiles = num_n_tiles * num_k_tiles + total_tiles = num_n_tiles * num_k_tiles - # Process tiles in strided manner across SMs - for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): - # Calculate tile indices + for tile_idx in range(tbidx, total_tiles, NUM_SMS): tile_n_idx = tile_idx % num_n_tiles tile_k_idx = tile_idx // num_n_tiles - # Calculate global offsets - n_offset = tile_n_idx * BLOCK_SIZE_N - k_offset = tile_k_idx * BLOCK_SIZE_K + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + k_offset = (tile_k_idx * BLOCK_SIZE_K).to(tl.int32) - # Initialize accumulator for this output tile [N, K] accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) - # Process each group - M_end = 0 + M_end = tl.full([], 0, dtype=tl.int32) for g in range(G): - # Get group boundaries M_start = M_end m_size = tl.load(m_sizes + g) M_end = M_start + m_size - # Only process if group is non-empty if m_size > 0: - # Process this group in chunks along the M dimension for m_offset in range(0, m_size, BLOCK_SIZE_M): - # Calculate actual block size (handling boundary) - m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + m_global = (M_start + m_offset).to(tl.int32) + grad_block = grad_output_desc.load([m_global, n_offset]) + x_block = x_desc.load([m_global, k_offset]) + accumulator += tl.dot( + grad_block.to(tl.float32).T, + x_block.to(tl.float32), + ) - # Only process if we have actual work to do - if m_block_size > 0: - # Global offset for this chunk - m_global_offset = M_start + m_offset - - if USE_TMA_LOAD: - # Load input chunk [M_chunk, K] using TMA - x_block = tl._experimental_descriptor_load( - x_desc_ptr, - [m_global_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - c_dtype, - ) - - # Load grad_output chunk [M_chunk, N] using TMA - grad_output_block = tl._experimental_descriptor_load( - grad_output_desc_ptr, - [m_global_offset, n_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - c_dtype, - ) - - # Apply masks for valid regions - offs_m = tl.arange(0, BLOCK_SIZE_M) - m_mask = offs_m < m_block_size - - # Zero out invalid elements - x_block = tl.where(m_mask[:, None], x_block, 0.0) - grad_output_block = tl.where( - m_mask[:, None], grad_output_block, 0.0 - ) - else: - # Manual load with bounds checking - offs_m = tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - # Create masks - m_mask = offs_m < m_block_size - n_mask = offs_n < N - n_offset - k_mask = offs_k < K - k_offset - - # Combined masks - mk_mask = m_mask[:, None] & k_mask[None, :] - mn_mask = m_mask[:, None] & n_mask[None, :] - - # Global offsets for loading - m_global_offs = m_global_offset + offs_m - - # Load x block [M_chunk, K] - x_block = tl.load( - x_desc_ptr - + m_global_offs[:, None] * K - + (k_offset + offs_k)[None, :], - mask=mk_mask, - other=0.0, - ) - - # Load grad_output block [M_chunk, N] - grad_output_block = tl.load( - grad_output_desc_ptr - + m_global_offs[:, None] * N - + (n_offset + offs_n)[None, :], - mask=mn_mask, - other=0.0, - ) - - # Compute partial contribution: grad_W += grad_Y.T @ X - # transpose grad_output for the matmul - contribution = tl.dot( - grad_output_block.to(tl.float32).T, # [N, M_chunk] - x_block.to(tl.float32), # [M_chunk, K] - ) - - # Accumulate - accumulator += contribution - - # Store the result - if USE_TMA_STORE: - # Store using TMA - tl._experimental_descriptor_store( - workspace, # TMA store descriptor - accumulator.to(c_dtype), - [n_offset, k_offset], - ) - else: - # Manual store with bounds checking - offs_n = tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - # Create masks for bounds checking - n_mask = offs_n < N - n_offset - k_mask = offs_k < K - k_offset - output_mask = n_mask[:, None] & k_mask[None, :] - - # Store the result - tl.store( - grad_weight_ptr - + (n_offset + offs_n)[:, None] * K - + (k_offset + offs_k)[None, :], - accumulator.to(c_dtype), - mask=output_mask, - ) + grad_weight_desc.store([n_offset, k_offset], accumulator.to(c_dtype)) +# ======== End Triton kernels ======== # ======== End Triton kernels ======== # ======== Triton wrapper functions ======== @@ -713,13 +340,13 @@ def grouped_gemm_forward( tma_size: int = 128, using_fp8: bool = False, ) -> torch.Tensor: - """ - M*G style grouped GEMM with TMA and Float8 support. - # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors. - - """ + """Grouped GEMM forward using Hopper TMA kernels.""" if not CudaUtils.verify_tma(): raise NotImplementedError("Grouped GEMM without TMA is not supported yet") + if using_fp8: + raise NotImplementedError( + "FP8 path not implemented with the new Triton API yet" + ) G = m_sizes.shape[0] @@ -727,83 +354,27 @@ def grouped_gemm_forward( assert w.is_contiguous() assert m_sizes.is_contiguous() - # Total input size is now [M_total, K] where M_total is the sum of all group sizes M_total, K = x.shape - N = w.shape[0] # N is now the same for all groups - + N = w.shape[0] assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})" - # Verify that all group sizes are multiples of ALIGN_SIZE_M - # This check is commented out because it will involve a GPU-CPU sync - # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M" - - # Create output tensor with correct shape [M_total, N] y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype) - if M_total == 0: return y NUM_SMS = CudaUtils.get_num_sms() - USE_TMA_LOAD = True - USE_TMA_STORE = True USE_EPILOGUE_SUBTILING = False - # TMA descriptor helper - desc_helper = None - desc_x = x - desc_w = w - workspace = None - - if USE_TMA_LOAD: - desc_helper = TmaDescriptorHelper(tma_size=tma_size) - desc_helper.init_tma_descriptor("x") - desc_helper.init_tma_descriptor("w") - desc_x = desc_helper.get_tma_descriptor_kernel_param("x") - desc_w = desc_helper.get_tma_descriptor_kernel_param("w") - - if USE_TMA_STORE: - if desc_helper is None: - raise RuntimeError( - "TMA descriptors must be initialized when USE_TMA_STORE is True" - ) - workspace = torch.empty( - NUM_SMS * desc_helper.tma_size, - device=x.device, - dtype=torch.uint8, - ) - - def grid(META): - if USE_TMA_LOAD: - nonlocal desc_helper - desc_helper.fill_2d_tma_descriptor( - "x", - x.data_ptr(), - M_total, - K, - META["BLOCK_SIZE_M"], - META["BLOCK_SIZE_K"], - x.element_size(), - ) - - desc_helper.fill_2d_tma_descriptor( - "w", - w.data_ptr(), - N, - K, - META["BLOCK_SIZE_N"], - META["BLOCK_SIZE_K"], - w.element_size(), - ) + def grid(_meta): return (NUM_SMS,) M_BUCKET = triton.next_power_of_2(M_total) - _kernel_mg_forward_hopper[grid]( - desc_x, - desc_w, + x, + w, y, - workspace, m_sizes, + M_total, G, M_BUCKET, N, @@ -812,7 +383,6 @@ def grouped_gemm_forward( TMA_SIZE=tma_size, USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING, ) - return y @@ -918,124 +488,44 @@ def grouped_gemm_dx_tma( num_sms: int = 132, tma_size: int = 128, ) -> torch.Tensor: - """ - Optimized backward pass wrapper for computing gradient with respect to input (dx) - using TMA patterns similar to the forward pass. - - Args: - grad_output: Gradient of output, shape [M_total, N] - w: Weight tensor, shape [N, K] - m_sizes: Group sizes tensor, shape [G] - tma_size: Size of TMA descriptor - # using_fp8: Whether to use FP8 quantization - # grad_output_scale: Scale for grad_output in FP8 mode - # w_scale: Scale for w in FP8 mode - - Returns: - grad_x: Gradient with respect to x, shape [M_total, K] - """ - """ - Optimized backward pass for computing gradient with respect to input (dx) - using TMA patterns similar to the forward pass. - - Args: - grad_output: Gradient of output, shape [M_total, N] - w: Weight tensor, shape [N, K] - m_sizes: Group sizes tensor, shape [G] - tma_size: Size of TMA descriptor - using_fp8: Whether to use FP8 quantization - # grad_output_scale: Scale for grad_output in FP8 mode - # w_scale: Scale for w in FP8 mode - - Returns: - grad_x: Gradient with respect to x, shape [M_total, K] - """ + """Compute grad_x using the Hopper grouped GEMM kernel.""" if not CudaUtils.verify_tma(): raise NotImplementedError("Optimized dx computation requires TMA support") - G = m_sizes.shape[0] + grad_output = grad_output.contiguous() + w = w.contiguous() + m_sizes = m_sizes.contiguous() - assert grad_output.is_contiguous() - assert w.is_contiguous() - assert m_sizes.is_contiguous() - - M_total, N_grad = grad_output.shape + M_total, N = grad_output.shape N_w, K = w.shape + if N != N_w: + raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})") - # Check dimensions - assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})" + if m_sizes.sum().item() != M_total: + raise ValueError("Sum of m_sizes must equal the number of rows in grad_output") - # Verify that the sum of m_sizes matches M_total - sum_m_sizes = m_sizes.sum().item() - assert M_total == sum_m_sizes, ( - f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" - ) - - # Create output tensor (grad_x) with shape [M_total, K] grad_x = torch.empty( (M_total, K), device=grad_output.device, dtype=grad_output.dtype ) - NUM_SMS = num_sms # CudaUtils.get_num_sms() - USE_TMA_LOAD = True - USE_TMA_STORE = True + NUM_SMS = num_sms - # Set up TMA descriptors - desc_helper = TmaDescriptorHelper(tma_size=tma_size) - desc_helper.init_tma_descriptor("grad_output") - desc_helper.init_tma_descriptor("w") - desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output") - desc_w = desc_helper.get_tma_descriptor_kernel_param("w") - - # Allocate workspace for TMA store - workspace = torch.empty( - NUM_SMS * desc_helper.tma_size, - device=grad_output.device, - dtype=torch.uint8, - ) - - def grid(META): - # Fill TMA descriptors with appropriate dimensions - desc_helper.fill_2d_tma_descriptor( - "grad_output", - grad_output.data_ptr(), - M_total, - N_grad, - META["BLOCK_SIZE_M"], - META["BLOCK_SIZE_N"], - grad_output.element_size(), - ) - - desc_helper.fill_2d_tma_descriptor( - "w", - w.data_ptr(), - N_w, - K, - META["BLOCK_SIZE_N"], - META["BLOCK_SIZE_K"], - w.element_size(), - ) + def grid(_meta): return (NUM_SMS,) M_BUCKET = triton.next_power_of_2(M_total) - - # Launch the flat linear kernel for computing grad_x _kernel_mg_dx_tma[grid]( - desc_grad_output, - desc_w, + grad_output, + w, grad_x, - workspace, m_sizes, - G, + M_total, + m_sizes.shape[0], M_BUCKET, - N_grad, # N dimension is now the reduction dimension + N, K, NUM_SMS, - USE_TMA_LOAD, - USE_TMA_STORE, - TMA_SIZE=tma_size, ) - return grad_x @@ -1049,139 +539,41 @@ def grouped_gemm_dw_tma( num_sms: int = 132, tma_size: int = 128, ) -> torch.Tensor: - """ - Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA. - For the forward pass Y = X @ W.T, the backward for weights is: - grad_W = grad_Y.T @ X - - Args: - x: Input tensor, shape [M_total, K] - grad_output: Gradient of output, shape [M_total, N] - m_sizes: Group sizes tensor, shape [G] - tma_size: Size of TMA descriptor in bytes - - - Returns: - grad_w: Gradient with respect to weights, shape [N, K] - """ - # Check TMA support + """Compute grad_w using the Hopper grouped GEMM kernel.""" if not CudaUtils.verify_tma(): raise RuntimeError("TMA grouped GEMM requested on a device without TMA support") - # Get group count - G = m_sizes.shape[0] - - # Ensure contiguous tensors x = x.contiguous() grad_output = grad_output.contiguous() m_sizes = m_sizes.contiguous() - # Get dimensions - M_total, K_x = x.shape + M_total, K = x.shape M_grad, N = grad_output.shape + if M_total != M_grad: + raise ValueError("x and grad_output must have matching batch dimension") + if m_sizes.sum().item() != M_total: + raise ValueError("Sum of m_sizes must equal the number of rows in the inputs") - # Check dimensions - assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})" - - # Verify that the sum of m_sizes matches M_total - sum_m_sizes = m_sizes.sum().item() - assert sum_m_sizes == M_total, ( - f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" - ) - - # Create output tensor (grad_w) with shape [N, K] - grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype) + grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype) NUM_SMS = num_sms - # TODO - hardcoded for now...but should set TMA flags based on hardware support - USE_TMA_LOAD = True - USE_TMA_STORE = True - - # Set up TMA descriptors or direct pointers - if USE_TMA_LOAD or USE_TMA_STORE: - desc_helper = TmaDescriptorHelper(tma_size=tma_size) - - if USE_TMA_LOAD: - desc_helper.init_tma_descriptor("x") - desc_helper.init_tma_descriptor("grad_output") - x_desc = desc_helper.get_tma_descriptor_kernel_param("x") - grad_output_desc = desc_helper.get_tma_descriptor_kernel_param( - "grad_output" - ) - else: - x_desc = x - grad_output_desc = grad_output - - if USE_TMA_STORE: - desc_helper.init_tma_descriptor("grad_w") - workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w") - else: - workspace = torch.empty(1, device=x.device, dtype=torch.uint8) - else: - # If not using TMA, just use the tensors directly - x_desc = x - grad_output_desc = grad_output - workspace = torch.empty(1, device=x.device, dtype=torch.uint8) - - # M_BUCKET for grid size - M_BUCKET = triton.next_power_of_2(M_total) - - # Define grid for kernel launch - def grid(META): - if USE_TMA_LOAD or USE_TMA_STORE: - if USE_TMA_LOAD: - desc_helper.fill_2d_tma_descriptor( - "x", - x.data_ptr(), - M_total, - K_x, - META["BLOCK_SIZE_M"], - META["BLOCK_SIZE_K"], - x.element_size(), - ) - - desc_helper.fill_2d_tma_descriptor( - "grad_output", - grad_output.data_ptr(), - M_total, - N, - META["BLOCK_SIZE_M"], - META["BLOCK_SIZE_N"], - grad_output.element_size(), - ) - - if USE_TMA_STORE: - desc_helper.fill_2d_tma_descriptor( - "grad_w", - grad_w.data_ptr(), - N, - K_x, - META["BLOCK_SIZE_N"], - META["BLOCK_SIZE_K"], - grad_w.element_size(), - ) - - # Return grid size - one block per SM for balanced work distribution + def grid(_meta): return (NUM_SMS,) - # Launch the optimized kernel + M_BUCKET = triton.next_power_of_2(M_total) _kernel_mg_dw_tma[grid]( - x_desc, - grad_output_desc, + x, + grad_output, grad_w, - workspace, m_sizes, - G, + M_total, + m_sizes.shape[0], M_BUCKET, N, - K_x, + K, NUM_SMS, - USE_TMA_LOAD, - USE_TMA_STORE, - TMA_SIZE=tma_size, ) - return grad_w @@ -1262,7 +654,7 @@ class GroupedGemmMg(torch.autograd.Function): ) # Return gradients for all inputs (None for non-differentiable parameters) - return grad_x, grad_w, None, None + return grad_x, grad_w, None, None, None, None def mg_grouped_gemm(