diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py index 9e540701e..6149dfc72 100644 --- a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py +++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py @@ -81,86 +81,92 @@ def _kernel_mg_forward_hopper( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size <= 0: - continue + if m_size > 0: + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - group_num_tiles = num_m_tiles * num_n_tiles + while ( + tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles - while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: - group_index = tbidx - processed_tiles + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles - tile_m_index = group_index % num_m_tiles - tile_n_index = group_index // num_m_tiles + rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining - rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M - rows_remaining = tl.maximum(rows_remaining, 0) - row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining + cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N + col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining - cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N - col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (g * n_size + n_offset).to(tl.int32) - m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) - global_n_offset = (g * n_size + n_offset).to(tl.int32) + for k_offset in range(0, K, BLOCK_SIZE_K): + k_remaining = K - k_offset + k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining - for k_offset in range(0, K, BLOCK_SIZE_K): - k_remaining = K - k_offset - k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining + a = a_desc.load([m_offset, k_offset]) + a_mask = row_mask[:, None] & k_mask[None, :] + a = tl.where(a_mask, a, tl.zeros_like(a)) - a = a_desc.load([m_offset, k_offset]) - a_mask = row_mask[:, None] & k_mask[None, :] - a = tl.where(a_mask, a, tl.zeros_like(a)) + b = b_desc.load([global_n_offset, k_offset]) + b_mask = col_mask[:, None] & k_mask[None, :] + b = tl.where(b_mask, b, tl.zeros_like(b)) - b = b_desc.load([global_n_offset, k_offset]) - b_mask = col_mask[:, None] & k_mask[None, :] - b = tl.where(b_mask, b, tl.zeros_like(b)) + accumulator += tl.dot(a, b.T) - accumulator += tl.dot(a, b.T) + local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) - local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M) + row_store_mask = local_row_offsets < m_size + global_row = (M_start + local_row_offsets).to(tl.int32) - local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M) - row_store_mask = local_row_offsets < m_size - global_row = (M_start + local_row_offsets).to(tl.int32) - - local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - col_store_mask = local_col_offsets < n_size - - store_mask = row_store_mask[:, None] & col_store_mask[None, :] - - if USE_EPILOGUE_SUBTILING: - acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - - col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2] - col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2] - ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :] - tl.store( - ptr0, - acc0.to(c_dtype), - mask=row_store_mask[:, None] & col_mask0[None, :], + local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N ) + col_store_mask = local_col_offsets < n_size - col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :] - col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :] - ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :] - tl.store( - ptr1, - acc1.to(c_dtype), - mask=row_store_mask[:, None] & col_mask1[None, :], - ) - else: - ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :] - tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) + store_mask = row_store_mask[:, None] & col_store_mask[None, :] - tbidx += NUM_SMS + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) - processed_tiles += group_num_tiles + col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2] + col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2] + ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :] + tl.store( + ptr0, + acc0.to(c_dtype), + mask=row_store_mask[:, None] & col_mask0[None, :], + ) + + col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :] + col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :] + ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :] + tl.store( + ptr1, + acc1.to(c_dtype), + mask=row_store_mask[:, None] & col_mask1[None, :], + ) + else: + ptr = ( + c_ptr + + global_row[:, None] * n_size + + local_col_offsets[None, :] + ) + tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) + + tbidx += NUM_SMS + + processed_tiles += group_num_tiles """ @@ -220,60 +226,62 @@ def _kernel_mg_dx_tma( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size <= 0: - continue + if m_size > 0: + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) - group_num_tiles = num_m_tiles * num_k_tiles + while ( + tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles - while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: - group_index = tbidx - processed_tiles + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles - tile_m_index = group_index % num_m_tiles - tile_k_index = group_index // num_m_tiles + rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining - rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M - rows_remaining = tl.maximum(rows_remaining, 0) - row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining + k_offset = tile_k_index * BLOCK_SIZE_K + k_remaining_total = K - k_offset + k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total - k_offset = tile_k_index * BLOCK_SIZE_K - k_remaining_total = K - k_offset - k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) - m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) + for n_offset in range(0, N, BLOCK_SIZE_N): + n_remaining = N - n_offset + n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining - for n_offset in range(0, N, BLOCK_SIZE_N): - n_remaining = N - n_offset - n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining + grad_y = grad_output_desc.load([m_offset, n_offset]) + grad_y_mask = row_mask[:, None] & n_mask[None, :] + grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y)) - grad_y = grad_output_desc.load([m_offset, n_offset]) - grad_y_mask = row_mask[:, None] & n_mask[None, :] - grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y)) + w_tile = w_desc.load([n_offset, k_offset]) + w_mask = n_mask[:, None] & k_mask[None, :] + w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile)) - w_tile = w_desc.load([n_offset, k_offset]) - w_mask = n_mask[:, None] & k_mask[None, :] - w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile)) + accumulator += tl.dot(grad_y, w_tile) - accumulator += tl.dot(grad_y, w_tile) + local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + row_store_mask = local_row_offsets < m_size + global_row = (M_start + local_row_offsets).to(tl.int32) - local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - row_store_mask = local_row_offsets < m_size - global_row = (M_start + local_row_offsets).to(tl.int32) + col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K) + col_store_mask = col_offsets < K - col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K) - col_store_mask = col_offsets < K + store_mask = row_store_mask[:, None] & col_store_mask[None, :] - store_mask = row_store_mask[:, None] & col_store_mask[None, :] + ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :] + tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) - ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :] - tl.store(ptr, accumulator.to(c_dtype), mask=store_mask) + tbidx += NUM_SMS - tbidx += NUM_SMS - - processed_tiles += group_num_tiles + processed_tiles += group_num_tiles @triton.autotune( @@ -342,29 +350,29 @@ def _kernel_mg_dw_tma( m_size = tl.load(m_sizes + g) M_end = M_start + m_size - if m_size <= 0: - continue + if m_size > 0: + for m_offset_local in range(0, m_size, BLOCK_SIZE_M): + rows_remaining = m_size - m_offset_local + rows_remaining = tl.maximum(rows_remaining, 0) + row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining - for m_offset_local in range(0, m_size, BLOCK_SIZE_M): - rows_remaining = m_size - m_offset_local - rows_remaining = tl.maximum(rows_remaining, 0) - row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining + m_offset = (M_start + m_offset_local).to(tl.int32) - m_offset = (M_start + m_offset_local).to(tl.int32) + x_block = x_desc.load([m_offset, k_offset]) + x_mask = row_mask[:, None] & k_mask[None, :] + x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block)) - x_block = x_desc.load([m_offset, k_offset]) - x_mask = row_mask[:, None] & k_mask[None, :] - x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block)) + grad_block = grad_output_desc.load([m_offset, n_offset]) + grad_mask = row_mask[:, None] & n_mask[None, :] + grad_block = tl.where( + grad_mask, grad_block, tl.zeros_like(grad_block) + ) - grad_block = grad_output_desc.load([m_offset, n_offset]) - grad_mask = row_mask[:, None] & n_mask[None, :] - grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block)) - - contribution = tl.dot( - grad_block.to(tl.float32).T, - x_block.to(tl.float32), - ) - accumulator += contribution + contribution = tl.dot( + grad_block.to(tl.float32).T, + x_block.to(tl.float32), + ) + accumulator += contribution row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N) row_store_mask = row_offsets < N