This commit is contained in:
Dan Saunders
2025-09-23 18:13:53 -04:00
parent 4db7a21ff7
commit d578c53603

View File

@@ -81,86 +81,92 @@ def _kernel_mg_forward_hopper(
m_size = tl.load(m_sizes + g) m_size = tl.load(m_sizes + g)
M_end = M_start + m_size M_end = M_start + m_size
if m_size <= 0: if m_size > 0:
continue num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
group_num_tiles = num_m_tiles * num_n_tiles
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) while (
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
group_num_tiles = num_m_tiles * num_n_tiles ):
group_index = tbidx - processed_tiles
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: tile_m_index = group_index % num_m_tiles
group_index = tbidx - processed_tiles tile_n_index = group_index // num_m_tiles
tile_m_index = group_index % num_m_tiles rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
tile_n_index = group_index // num_m_tiles rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
rows_remaining = tl.maximum(rows_remaining, 0) col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
global_n_offset = (g * n_size + n_offset).to(tl.int32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K):
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) k_remaining = K - k_offset
global_n_offset = (g * n_size + n_offset).to(tl.int32) k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
for k_offset in range(0, K, BLOCK_SIZE_K): a = a_desc.load([m_offset, k_offset])
k_remaining = K - k_offset a_mask = row_mask[:, None] & k_mask[None, :]
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining a = tl.where(a_mask, a, tl.zeros_like(a))
a = a_desc.load([m_offset, k_offset]) b = b_desc.load([global_n_offset, k_offset])
a_mask = row_mask[:, None] & k_mask[None, :] b_mask = col_mask[:, None] & k_mask[None, :]
a = tl.where(a_mask, a, tl.zeros_like(a)) b = tl.where(b_mask, b, tl.zeros_like(b))
b = b_desc.load([global_n_offset, k_offset]) accumulator += tl.dot(a, b.T)
b_mask = col_mask[:, None] & k_mask[None, :]
b = tl.where(b_mask, b, tl.zeros_like(b))
accumulator += tl.dot(a, b.T) local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M) local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
row_store_mask = local_row_offsets < m_size 0, BLOCK_SIZE_N
global_row = (M_start + local_row_offsets).to(tl.int32)
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
col_store_mask = local_col_offsets < n_size
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
if USE_EPILOGUE_SUBTILING:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
tl.store(
ptr0,
acc0.to(c_dtype),
mask=row_store_mask[:, None] & col_mask0[None, :],
) )
col_store_mask = local_col_offsets < n_size
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :] store_mask = row_store_mask[:, None] & col_store_mask[None, :]
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
tl.store(
ptr1,
acc1.to(c_dtype),
mask=row_store_mask[:, None] & col_mask1[None, :],
)
else:
ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS if USE_EPILOGUE_SUBTILING:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
processed_tiles += group_num_tiles col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
tl.store(
ptr0,
acc0.to(c_dtype),
mask=row_store_mask[:, None] & col_mask0[None, :],
)
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
tl.store(
ptr1,
acc1.to(c_dtype),
mask=row_store_mask[:, None] & col_mask1[None, :],
)
else:
ptr = (
c_ptr
+ global_row[:, None] * n_size
+ local_col_offsets[None, :]
)
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
""" """
@@ -220,60 +226,62 @@ def _kernel_mg_dx_tma(
m_size = tl.load(m_sizes + g) m_size = tl.load(m_sizes + g)
M_end = M_start + m_size M_end = M_start + m_size
if m_size <= 0: if m_size > 0:
continue num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
group_num_tiles = num_m_tiles * num_k_tiles
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) while (
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
group_num_tiles = num_m_tiles * num_k_tiles ):
group_index = tbidx - processed_tiles
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles: tile_m_index = group_index % num_m_tiles
group_index = tbidx - processed_tiles tile_k_index = group_index // num_m_tiles
tile_m_index = group_index % num_m_tiles rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
tile_k_index = group_index // num_m_tiles rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M k_offset = tile_k_index * BLOCK_SIZE_K
rows_remaining = tl.maximum(rows_remaining, 0) k_remaining_total = K - k_offset
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
k_offset = tile_k_index * BLOCK_SIZE_K accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
k_remaining_total = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32) for n_offset in range(0, N, BLOCK_SIZE_N):
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
for n_offset in range(0, N, BLOCK_SIZE_N): grad_y = grad_output_desc.load([m_offset, n_offset])
n_remaining = N - n_offset grad_y_mask = row_mask[:, None] & n_mask[None, :]
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
grad_y = grad_output_desc.load([m_offset, n_offset]) w_tile = w_desc.load([n_offset, k_offset])
grad_y_mask = row_mask[:, None] & n_mask[None, :] w_mask = n_mask[:, None] & k_mask[None, :]
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y)) w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
w_tile = w_desc.load([n_offset, k_offset]) accumulator += tl.dot(grad_y, w_tile)
w_mask = n_mask[:, None] & k_mask[None, :]
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
accumulator += tl.dot(grad_y, w_tile) local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
row_store_mask = local_row_offsets < m_size col_store_mask = col_offsets < K
global_row = (M_start + local_row_offsets).to(tl.int32)
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K) store_mask = row_store_mask[:, None] & col_store_mask[None, :]
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :] ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :] tbidx += NUM_SMS
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS processed_tiles += group_num_tiles
processed_tiles += group_num_tiles
@triton.autotune( @triton.autotune(
@@ -342,29 +350,29 @@ def _kernel_mg_dw_tma(
m_size = tl.load(m_sizes + g) m_size = tl.load(m_sizes + g)
M_end = M_start + m_size M_end = M_start + m_size
if m_size <= 0: if m_size > 0:
continue for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
rows_remaining = m_size - m_offset_local
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
for m_offset_local in range(0, m_size, BLOCK_SIZE_M): m_offset = (M_start + m_offset_local).to(tl.int32)
rows_remaining = m_size - m_offset_local
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
m_offset = (M_start + m_offset_local).to(tl.int32) x_block = x_desc.load([m_offset, k_offset])
x_mask = row_mask[:, None] & k_mask[None, :]
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
x_block = x_desc.load([m_offset, k_offset]) grad_block = grad_output_desc.load([m_offset, n_offset])
x_mask = row_mask[:, None] & k_mask[None, :] grad_mask = row_mask[:, None] & n_mask[None, :]
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block)) grad_block = tl.where(
grad_mask, grad_block, tl.zeros_like(grad_block)
)
grad_block = grad_output_desc.load([m_offset, n_offset]) contribution = tl.dot(
grad_mask = row_mask[:, None] & n_mask[None, :] grad_block.to(tl.float32).T,
grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block)) x_block.to(tl.float32),
)
contribution = tl.dot( accumulator += contribution
grad_block.to(tl.float32).T,
x_block.to(tl.float32),
)
accumulator += contribution
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N) row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
row_store_mask = row_offsets < N row_store_mask = row_offsets < N