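fix: replace the `if m_size <= 0: continue` early-skip with an `if m_size > 0:` guard around the tile loops in `_kernel_mg_forward_hopper`, `_kernel_mg_dx_tma`, and `_kernel_mg_dw_tma`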
This commit is contained in:
@@ -81,86 +81,92 @@ def _kernel_mg_forward_hopper(
|
|||||||
m_size = tl.load(m_sizes + g)
|
m_size = tl.load(m_sizes + g)
|
||||||
M_end = M_start + m_size
|
M_end = M_start + m_size
|
||||||
|
|
||||||
if m_size <= 0:
|
if m_size > 0:
|
||||||
continue
|
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||||
|
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
||||||
|
group_num_tiles = num_m_tiles * num_n_tiles
|
||||||
|
|
||||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
while (
|
||||||
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||||
group_num_tiles = num_m_tiles * num_n_tiles
|
):
|
||||||
|
group_index = tbidx - processed_tiles
|
||||||
|
|
||||||
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
|
tile_m_index = group_index % num_m_tiles
|
||||||
group_index = tbidx - processed_tiles
|
tile_n_index = group_index // num_m_tiles
|
||||||
|
|
||||||
tile_m_index = group_index % num_m_tiles
|
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||||
tile_n_index = group_index // num_m_tiles
|
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||||
|
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||||
|
|
||||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
||||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
||||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
|
||||||
|
|
||||||
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
|
||||||
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||||
|
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
||||||
|
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
||||||
|
|
||||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||||
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
k_remaining = K - k_offset
|
||||||
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||||
|
|
||||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
a = a_desc.load([m_offset, k_offset])
|
||||||
k_remaining = K - k_offset
|
a_mask = row_mask[:, None] & k_mask[None, :]
|
||||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
a = tl.where(a_mask, a, tl.zeros_like(a))
|
||||||
|
|
||||||
a = a_desc.load([m_offset, k_offset])
|
b = b_desc.load([global_n_offset, k_offset])
|
||||||
a_mask = row_mask[:, None] & k_mask[None, :]
|
b_mask = col_mask[:, None] & k_mask[None, :]
|
||||||
a = tl.where(a_mask, a, tl.zeros_like(a))
|
b = tl.where(b_mask, b, tl.zeros_like(b))
|
||||||
|
|
||||||
b = b_desc.load([global_n_offset, k_offset])
|
accumulator += tl.dot(a, b.T)
|
||||||
b_mask = col_mask[:, None] & k_mask[None, :]
|
|
||||||
b = tl.where(b_mask, b, tl.zeros_like(b))
|
|
||||||
|
|
||||||
accumulator += tl.dot(a, b.T)
|
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||||
|
|
||||||
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
row_store_mask = local_row_offsets < m_size
|
||||||
|
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||||
|
|
||||||
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
|
||||||
row_store_mask = local_row_offsets < m_size
|
0, BLOCK_SIZE_N
|
||||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
|
||||||
|
|
||||||
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
col_store_mask = local_col_offsets < n_size
|
|
||||||
|
|
||||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
|
||||||
|
|
||||||
if USE_EPILOGUE_SUBTILING:
|
|
||||||
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
|
||||||
acc = tl.permute(acc, (0, 2, 1))
|
|
||||||
acc0, acc1 = tl.split(acc)
|
|
||||||
|
|
||||||
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
|
||||||
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
|
||||||
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
|
||||||
tl.store(
|
|
||||||
ptr0,
|
|
||||||
acc0.to(c_dtype),
|
|
||||||
mask=row_store_mask[:, None] & col_mask0[None, :],
|
|
||||||
)
|
)
|
||||||
|
col_store_mask = local_col_offsets < n_size
|
||||||
|
|
||||||
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||||
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
|
||||||
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
|
||||||
tl.store(
|
|
||||||
ptr1,
|
|
||||||
acc1.to(c_dtype),
|
|
||||||
mask=row_store_mask[:, None] & col_mask1[None, :],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :]
|
|
||||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
|
||||||
|
|
||||||
tbidx += NUM_SMS
|
if USE_EPILOGUE_SUBTILING:
|
||||||
|
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
||||||
|
acc = tl.permute(acc, (0, 2, 1))
|
||||||
|
acc0, acc1 = tl.split(acc)
|
||||||
|
|
||||||
processed_tiles += group_num_tiles
|
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
||||||
|
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
||||||
|
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
||||||
|
tl.store(
|
||||||
|
ptr0,
|
||||||
|
acc0.to(c_dtype),
|
||||||
|
mask=row_store_mask[:, None] & col_mask0[None, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
||||||
|
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
||||||
|
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
||||||
|
tl.store(
|
||||||
|
ptr1,
|
||||||
|
acc1.to(c_dtype),
|
||||||
|
mask=row_store_mask[:, None] & col_mask1[None, :],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ptr = (
|
||||||
|
c_ptr
|
||||||
|
+ global_row[:, None] * n_size
|
||||||
|
+ local_col_offsets[None, :]
|
||||||
|
)
|
||||||
|
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||||
|
|
||||||
|
tbidx += NUM_SMS
|
||||||
|
|
||||||
|
processed_tiles += group_num_tiles
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -220,60 +226,62 @@ def _kernel_mg_dx_tma(
|
|||||||
m_size = tl.load(m_sizes + g)
|
m_size = tl.load(m_sizes + g)
|
||||||
M_end = M_start + m_size
|
M_end = M_start + m_size
|
||||||
|
|
||||||
if m_size <= 0:
|
if m_size > 0:
|
||||||
continue
|
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||||
|
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
|
group_num_tiles = num_m_tiles * num_k_tiles
|
||||||
|
|
||||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
while (
|
||||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||||
group_num_tiles = num_m_tiles * num_k_tiles
|
):
|
||||||
|
group_index = tbidx - processed_tiles
|
||||||
|
|
||||||
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
|
tile_m_index = group_index % num_m_tiles
|
||||||
group_index = tbidx - processed_tiles
|
tile_k_index = group_index // num_m_tiles
|
||||||
|
|
||||||
tile_m_index = group_index % num_m_tiles
|
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||||
tile_k_index = group_index // num_m_tiles
|
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||||
|
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||||
|
|
||||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
k_offset = tile_k_index * BLOCK_SIZE_K
|
||||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
k_remaining_total = K - k_offset
|
||||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
||||||
|
|
||||||
k_offset = tile_k_index * BLOCK_SIZE_K
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||||
k_remaining_total = K - k_offset
|
|
||||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
|
||||||
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||||
|
|
||||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||||
|
n_remaining = N - n_offset
|
||||||
|
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||||
|
|
||||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
grad_y = grad_output_desc.load([m_offset, n_offset])
|
||||||
n_remaining = N - n_offset
|
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
||||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
||||||
|
|
||||||
grad_y = grad_output_desc.load([m_offset, n_offset])
|
w_tile = w_desc.load([n_offset, k_offset])
|
||||||
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
w_mask = n_mask[:, None] & k_mask[None, :]
|
||||||
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
||||||
|
|
||||||
w_tile = w_desc.load([n_offset, k_offset])
|
accumulator += tl.dot(grad_y, w_tile)
|
||||||
w_mask = n_mask[:, None] & k_mask[None, :]
|
|
||||||
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
|
||||||
|
|
||||||
accumulator += tl.dot(grad_y, w_tile)
|
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
|
||||||
|
0, BLOCK_SIZE_M
|
||||||
|
)
|
||||||
|
row_store_mask = local_row_offsets < m_size
|
||||||
|
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||||
|
|
||||||
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||||
row_store_mask = local_row_offsets < m_size
|
col_store_mask = col_offsets < K
|
||||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
|
||||||
|
|
||||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||||
col_store_mask = col_offsets < K
|
|
||||||
|
|
||||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
||||||
|
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||||
|
|
||||||
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
tbidx += NUM_SMS
|
||||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
|
||||||
|
|
||||||
tbidx += NUM_SMS
|
processed_tiles += group_num_tiles
|
||||||
|
|
||||||
processed_tiles += group_num_tiles
|
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
@@ -342,29 +350,29 @@ def _kernel_mg_dw_tma(
|
|||||||
m_size = tl.load(m_sizes + g)
|
m_size = tl.load(m_sizes + g)
|
||||||
M_end = M_start + m_size
|
M_end = M_start + m_size
|
||||||
|
|
||||||
if m_size <= 0:
|
if m_size > 0:
|
||||||
continue
|
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
||||||
|
rows_remaining = m_size - m_offset_local
|
||||||
|
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||||
|
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||||
|
|
||||||
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
m_offset = (M_start + m_offset_local).to(tl.int32)
|
||||||
rows_remaining = m_size - m_offset_local
|
|
||||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
|
||||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
|
||||||
|
|
||||||
m_offset = (M_start + m_offset_local).to(tl.int32)
|
x_block = x_desc.load([m_offset, k_offset])
|
||||||
|
x_mask = row_mask[:, None] & k_mask[None, :]
|
||||||
|
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
||||||
|
|
||||||
x_block = x_desc.load([m_offset, k_offset])
|
grad_block = grad_output_desc.load([m_offset, n_offset])
|
||||||
x_mask = row_mask[:, None] & k_mask[None, :]
|
grad_mask = row_mask[:, None] & n_mask[None, :]
|
||||||
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
grad_block = tl.where(
|
||||||
|
grad_mask, grad_block, tl.zeros_like(grad_block)
|
||||||
|
)
|
||||||
|
|
||||||
grad_block = grad_output_desc.load([m_offset, n_offset])
|
contribution = tl.dot(
|
||||||
grad_mask = row_mask[:, None] & n_mask[None, :]
|
grad_block.to(tl.float32).T,
|
||||||
grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block))
|
x_block.to(tl.float32),
|
||||||
|
)
|
||||||
contribution = tl.dot(
|
accumulator += contribution
|
||||||
grad_block.to(tl.float32).T,
|
|
||||||
x_block.to(tl.float32),
|
|
||||||
)
|
|
||||||
accumulator += contribution
|
|
||||||
|
|
||||||
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
|
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
|
||||||
row_store_mask = row_offsets < N
|
row_store_mask = row_offsets < N
|
||||||
|
|||||||
Reference in New Issue
Block a user