Dan Saunders
2025-09-23 18:03:41 -04:00
parent 3b2e05c563
commit 4db7a21ff7

@@ -48,7 +48,6 @@ def _kernel_mg_forward_hopper(
     K: tl.constexpr,
     # config
     NUM_SMS: tl.constexpr,
-    TMA_SIZE: tl.constexpr,
     USE_EPILOGUE_SUBTILING: tl.constexpr,
     # tiles
     BLOCK_SIZE_M: tl.constexpr,
@@ -56,27 +55,21 @@ def _kernel_mg_forward_hopper(
     BLOCK_SIZE_K: tl.constexpr,
 ) -> None:
     """Flat index style forward kernel for Hopper using tensor descriptors."""
-    tbidx = tl.program_id(0) # thread block index
-    c_dtype = c_ptr.dtype.element_ty # output dtype
+    tbidx = tl.program_id(0)
+    c_dtype = c_ptr.dtype.element_ty
     n_size = N // G
-    # Compute the total rows spanned by all groups so descriptors cover valid ranges.
-    total_rows = M_TOTAL.to(tl.int32)
-    stride_k = tl.full([], K, dtype=tl.int32)
-    stride_1 = tl.full([], 1, dtype=tl.int32)
     a_desc = tl.make_tensor_descriptor(
         a_ptr,
-        shape=[total_rows, stride_k],
-        strides=[stride_k, stride_1],
+        shape=[M_TOTAL, K],
+        strides=[K, 1],
         block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
     )
     b_desc = tl.make_tensor_descriptor(
         b_ptr,
-        shape=[tl.full([], N, dtype=tl.int32), stride_k],
-        strides=[stride_k, stride_1],
+        shape=[N, K],
+        strides=[K, 1],
         block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
     )
@@ -88,55 +81,86 @@ def _kernel_mg_forward_hopper(
         m_size = tl.load(m_sizes + g)
         M_end = M_start + m_size
-        if m_size > 0:
-            group_base = c_ptr + M_start * n_size
-            c_desc = tl.make_tensor_descriptor(
-                group_base,
-                shape=[m_size, tl.full([], n_size, dtype=tl.int32)],
-                strides=[tl.full([], n_size, dtype=tl.int32), stride_1],
-                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-            )
-            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
-            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
-            group_num_tiles = num_m_tiles * num_n_tiles
-            while (
-                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
-            ):
-                group_index = tbidx - processed_tiles
-                tile_m_index = group_index % num_m_tiles
-                tile_n_index = group_index // num_m_tiles
-                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-                local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
-                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
-                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
-                global_n_offset = (g * n_size + n_offset).to(tl.int32)
-                for k_offset in range(0, K, BLOCK_SIZE_K):
-                    a = a_desc.load([m_offset, k_offset])
-                    b = b_desc.load([global_n_offset, k_offset])
-                    accumulator += tl.dot(a, b.T)
-                if USE_EPILOGUE_SUBTILING:
-                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
-                    acc = tl.permute(acc, (0, 2, 1))
-                    acc0, acc1 = tl.split(acc)
-                    c_desc.store([local_m_offset, n_offset], acc0.to(c_dtype))
-                    c_desc.store(
-                        [local_m_offset, n_offset + BLOCK_SIZE_N // 2],
-                        acc1.to(c_dtype),
-                    )
-                else:
-                    c_desc.store([local_m_offset, n_offset], accumulator.to(c_dtype))
-                tbidx += NUM_SMS
-            processed_tiles += group_num_tiles
+        if m_size <= 0:
+            continue
+        num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+        num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
+        group_num_tiles = num_m_tiles * num_n_tiles
+        while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
+            group_index = tbidx - processed_tiles
+            tile_m_index = group_index % num_m_tiles
+            tile_n_index = group_index // num_m_tiles
+            rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
+            rows_remaining = tl.maximum(rows_remaining, 0)
+            row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
+            cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
+            col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+            m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
+            n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
+            global_n_offset = (g * n_size + n_offset).to(tl.int32)
+            for k_offset in range(0, K, BLOCK_SIZE_K):
+                k_remaining = K - k_offset
+                k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
+                a = a_desc.load([m_offset, k_offset])
+                a_mask = row_mask[:, None] & k_mask[None, :]
+                a = tl.where(a_mask, a, tl.zeros_like(a))
+                b = b_desc.load([global_n_offset, k_offset])
+                b_mask = col_mask[:, None] & k_mask[None, :]
+                b = tl.where(b_mask, b, tl.zeros_like(b))
+                accumulator += tl.dot(a, b.T)
+            local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
+            local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
+            row_store_mask = local_row_offsets < m_size
+            global_row = (M_start + local_row_offsets).to(tl.int32)
+            local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+            col_store_mask = local_col_offsets < n_size
+            store_mask = row_store_mask[:, None] & col_store_mask[None, :]
+            if USE_EPILOGUE_SUBTILING:
+                acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
+                acc = tl.permute(acc, (0, 2, 1))
+                acc0, acc1 = tl.split(acc)
+                col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
+                col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
+                ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
+                tl.store(
+                    ptr0,
+                    acc0.to(c_dtype),
+                    mask=row_store_mask[:, None] & col_mask0[None, :],
+                )
+                col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
+                col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
+                ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
+                tl.store(
+                    ptr1,
+                    acc1.to(c_dtype),
+                    mask=row_store_mask[:, None] & col_mask1[None, :],
+                )
+            else:
+                ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :]
+                tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
+            tbidx += NUM_SMS
+        processed_tiles += group_num_tiles
"""
@@ -174,22 +198,17 @@ def _kernel_mg_dx_tma(
     tbidx = tl.program_id(0)
     c_dtype = grad_input_ptr.dtype.element_ty
-    stride_1 = tl.full([], 1, dtype=tl.int32)
-    stride_n = tl.full([], N, dtype=tl.int32)
-    stride_k = tl.full([], K, dtype=tl.int32)
-    total_rows = M_TOTAL.to(tl.int32)
     grad_output_desc = tl.make_tensor_descriptor(
         grad_output_ptr,
-        shape=[total_rows, stride_n],
-        strides=[stride_n, stride_1],
+        shape=[M_TOTAL, N],
+        strides=[N, 1],
         block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
     )
     w_desc = tl.make_tensor_descriptor(
         w_ptr,
-        shape=[stride_n, stride_k],
-        strides=[stride_k, stride_1],
+        shape=[N, K],
+        strides=[K, 1],
         block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
     )
@@ -201,43 +220,60 @@ def _kernel_mg_dx_tma(
         m_size = tl.load(m_sizes + g)
         M_end = M_start + m_size
-        if m_size > 0:
-            grad_input_desc = tl.make_tensor_descriptor(
-                grad_input_ptr + M_start * K,
-                shape=[m_size, stride_k],
-                strides=[stride_k, stride_1],
-                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
-            )
-            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
-            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
-            group_num_tiles = num_m_tiles * num_k_tiles
-            while (
-                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
-            ):
-                group_index = tbidx - processed_tiles
-                tile_m_index = group_index % num_m_tiles
-                tile_k_index = group_index // num_m_tiles
-                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
-                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
-                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
-                for n_offset in range(0, N, BLOCK_SIZE_N):
-                    grad_y = grad_output_desc.load([m_offset, n_offset])
-                    w_tile = w_desc.load([n_offset, k_offset])
-                    accumulator += tl.dot(grad_y, w_tile)
-                local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
-                grad_input_desc.store(
-                    [local_m_offset, k_offset], accumulator.to(c_dtype)
-                )
-                tbidx += NUM_SMS
-            processed_tiles += group_num_tiles
+        if m_size <= 0:
+            continue
+        num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+        num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+        group_num_tiles = num_m_tiles * num_k_tiles
+        while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
+            group_index = tbidx - processed_tiles
+            tile_m_index = group_index % num_m_tiles
+            tile_k_index = group_index // num_m_tiles
+            rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
+            rows_remaining = tl.maximum(rows_remaining, 0)
+            row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
+            k_offset = tile_k_index * BLOCK_SIZE_K
+            k_remaining_total = K - k_offset
+            k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+            m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
+            for n_offset in range(0, N, BLOCK_SIZE_N):
+                n_remaining = N - n_offset
+                n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
+                grad_y = grad_output_desc.load([m_offset, n_offset])
+                grad_y_mask = row_mask[:, None] & n_mask[None, :]
+                grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
+                w_tile = w_desc.load([n_offset, k_offset])
+                w_mask = n_mask[:, None] & k_mask[None, :]
+                w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
+                accumulator += tl.dot(grad_y, w_tile)
+            local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            row_store_mask = local_row_offsets < m_size
+            global_row = (M_start + local_row_offsets).to(tl.int32)
+            col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
+            col_store_mask = col_offsets < K
+            store_mask = row_store_mask[:, None] & col_store_mask[None, :]
+            ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
+            tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
+            tbidx += NUM_SMS
+        processed_tiles += group_num_tiles
 @triton.autotune(
@@ -268,30 +304,19 @@ def _kernel_mg_dw_tma(
     tbidx = tl.program_id(0)
     c_dtype = grad_weight_ptr.dtype.element_ty
-    stride_1 = tl.full([], 1, dtype=tl.int32)
-    stride_n = tl.full([], N, dtype=tl.int32)
-    stride_k = tl.full([], K, dtype=tl.int32)
-    total_rows = M_TOTAL.to(tl.int32)
     x_desc = tl.make_tensor_descriptor(
         x_ptr,
-        shape=[total_rows, stride_k],
-        strides=[stride_k, stride_1],
+        shape=[M_TOTAL, K],
+        strides=[K, 1],
         block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
     )
     grad_output_desc = tl.make_tensor_descriptor(
         grad_output_ptr,
-        shape=[total_rows, stride_n],
-        strides=[stride_n, stride_1],
+        shape=[M_TOTAL, N],
+        strides=[N, 1],
         block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
     )
-    grad_weight_desc = tl.make_tensor_descriptor(
-        grad_weight_ptr,
-        shape=[stride_n, stride_k],
-        strides=[stride_k, stride_1],
-        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
-    )
     num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
     num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
@@ -301,8 +326,13 @@ def _kernel_mg_dw_tma(
         tile_n_idx = tile_idx % num_n_tiles
         tile_k_idx = tile_idx // num_n_tiles
-        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-        k_offset = (tile_k_idx * BLOCK_SIZE_K).to(tl.int32)
+        n_offset = tile_n_idx * BLOCK_SIZE_N
+        n_remaining = N - n_offset
+        n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
+        k_offset = tile_k_idx * BLOCK_SIZE_K
+        k_remaining = K - k_offset
+        k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
@@ -312,17 +342,40 @@ def _kernel_mg_dw_tma(
             m_size = tl.load(m_sizes + g)
             M_end = M_start + m_size
-            if m_size > 0:
-                for m_offset in range(0, m_size, BLOCK_SIZE_M):
-                    m_global = (M_start + m_offset).to(tl.int32)
-                    grad_block = grad_output_desc.load([m_global, n_offset])
-                    x_block = x_desc.load([m_global, k_offset])
-                    accumulator += tl.dot(
-                        grad_block.to(tl.float32).T,
-                        x_block.to(tl.float32),
-                    )
-        grad_weight_desc.store([n_offset, k_offset], accumulator.to(c_dtype))
+            if m_size <= 0:
+                continue
+            for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
+                rows_remaining = m_size - m_offset_local
+                rows_remaining = tl.maximum(rows_remaining, 0)
+                row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
+                m_offset = (M_start + m_offset_local).to(tl.int32)
+                x_block = x_desc.load([m_offset, k_offset])
+                x_mask = row_mask[:, None] & k_mask[None, :]
+                x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
+                grad_block = grad_output_desc.load([m_offset, n_offset])
+                grad_mask = row_mask[:, None] & n_mask[None, :]
+                grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block))
+                contribution = tl.dot(
+                    grad_block.to(tl.float32).T,
+                    x_block.to(tl.float32),
+                )
+                accumulator += contribution
+        row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
+        row_store_mask = row_offsets < N
+        col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
+        col_store_mask = col_offsets < K
+        store_mask = row_store_mask[:, None] & col_store_mask[None, :]
+        ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
+        tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
 # ======== End Triton kernels ========
@@ -380,7 +433,6 @@ def grouped_gemm_forward(
         N,
         K,
         NUM_SMS,
-        TMA_SIZE=tma_size,
         USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
     )
     return y
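
Note on the pattern above: the commit replaces the per-group output tensor descriptors and unmasked descriptor I/O with explicit bounds handling. Each descriptor load is zeroed outside the valid tile via tl.where, and each store goes through raw pointer arithmetic with a mask, so ragged group edges (m_size not a multiple of BLOCK_SIZE_M, tails in N and K) neither feed garbage into tl.dot nor write outside their group. Below is a minimal self-contained sketch of that masked tile load/store idiom, separate from these kernels; all names in it are illustrative, not from this file, and plain pointer arithmetic stands in for the descriptor loads.

# Sketch of the masked tile load/store idiom the commit adopts.
# Illustrative only; not part of this commit.
import torch
import triton
import triton.language as tl


@triton.jit
def _masked_scale_kernel(
    src_ptr,
    dst_ptr,
    n_rows,
    n_cols,
    scale,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # 2D mask: True only for in-bounds elements of ragged edge tiles.
    mask = (rows < n_rows)[:, None] & (cols < n_cols)[None, :]
    ptrs = rows[:, None] * n_cols + cols[None, :]
    # Out-of-bounds lanes read 0.0, mirroring the tl.where zeroing above.
    vals = tl.load(src_ptr + ptrs, mask=mask, other=0.0)
    tl.store(dst_ptr + ptrs, vals * scale, mask=mask)


def masked_scale(src: torch.Tensor, scale: float) -> torch.Tensor:
    """Launch helper; assumes src is 2D and contiguous."""
    dst = torch.empty_like(src)
    m, n = src.shape
    grid = (triton.cdiv(m, 64), triton.cdiv(n, 64))
    _masked_scale_kernel[grid](src, dst, m, n, scale, BLOCK_M=64, BLOCK_N=64)
    return dst

The kernels in the diff keep the descriptor loads and mask afterwards with tl.where, since a descriptor load always fetches a whole block; the stores move to plain pointers plus a mask argument, which gives element-wise control that a whole-tile descriptor store does not.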