fix
This commit is contained in:
@@ -81,86 +81,92 @@ def _kernel_mg_forward_hopper(
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size <= 0:
|
||||
continue
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
||||
group_num_tiles = num_m_tiles * num_n_tiles
|
||||
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
||||
group_num_tiles = num_m_tiles * num_n_tiles
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
|
||||
group_index = tbidx - processed_tiles
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_n_index = group_index // num_m_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_n_index = group_index // num_m_tiles
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
||||
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
||||
|
||||
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
||||
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
||||
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
||||
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
|
||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
a = a_desc.load([m_offset, k_offset])
|
||||
a_mask = row_mask[:, None] & k_mask[None, :]
|
||||
a = tl.where(a_mask, a, tl.zeros_like(a))
|
||||
|
||||
a = a_desc.load([m_offset, k_offset])
|
||||
a_mask = row_mask[:, None] & k_mask[None, :]
|
||||
a = tl.where(a_mask, a, tl.zeros_like(a))
|
||||
b = b_desc.load([global_n_offset, k_offset])
|
||||
b_mask = col_mask[:, None] & k_mask[None, :]
|
||||
b = tl.where(b_mask, b, tl.zeros_like(b))
|
||||
|
||||
b = b_desc.load([global_n_offset, k_offset])
|
||||
b_mask = col_mask[:, None] & k_mask[None, :]
|
||||
b = tl.where(b_mask, b, tl.zeros_like(b))
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
accumulator += tl.dot(a, b.T)
|
||||
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
col_store_mask = local_col_offsets < n_size
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
if USE_EPILOGUE_SUBTILING:
|
||||
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
||||
acc = tl.permute(acc, (0, 2, 1))
|
||||
acc0, acc1 = tl.split(acc)
|
||||
|
||||
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
||||
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
||||
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
||||
tl.store(
|
||||
ptr0,
|
||||
acc0.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask0[None, :],
|
||||
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
|
||||
0, BLOCK_SIZE_N
|
||||
)
|
||||
col_store_mask = local_col_offsets < n_size
|
||||
|
||||
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
||||
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
||||
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
||||
tl.store(
|
||||
ptr1,
|
||||
acc1.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask1[None, :],
|
||||
)
|
||||
else:
|
||||
ptr = c_ptr + global_row[:, None] * n_size + local_col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
tbidx += NUM_SMS
|
||||
if USE_EPILOGUE_SUBTILING:
|
||||
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
||||
acc = tl.permute(acc, (0, 2, 1))
|
||||
acc0, acc1 = tl.split(acc)
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
||||
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
||||
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
||||
tl.store(
|
||||
ptr0,
|
||||
acc0.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask0[None, :],
|
||||
)
|
||||
|
||||
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
||||
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
||||
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
||||
tl.store(
|
||||
ptr1,
|
||||
acc1.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask1[None, :],
|
||||
)
|
||||
else:
|
||||
ptr = (
|
||||
c_ptr
|
||||
+ global_row[:, None] * n_size
|
||||
+ local_col_offsets[None, :]
|
||||
)
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
"""
|
||||
@@ -220,60 +226,62 @@ def _kernel_mg_dx_tma(
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size <= 0:
|
||||
continue
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
group_num_tiles = num_m_tiles * num_k_tiles
|
||||
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
group_num_tiles = num_m_tiles * num_k_tiles
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
while tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles:
|
||||
group_index = tbidx - processed_tiles
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_k_index = group_index // num_m_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_k_index = group_index // num_m_tiles
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
k_offset = tile_k_index * BLOCK_SIZE_K
|
||||
k_remaining_total = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
||||
|
||||
k_offset = tile_k_index * BLOCK_SIZE_K
|
||||
k_remaining_total = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
|
||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
grad_y = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
||||
|
||||
grad_y = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
||||
w_tile = w_desc.load([n_offset, k_offset])
|
||||
w_mask = n_mask[:, None] & k_mask[None, :]
|
||||
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
||||
|
||||
w_tile = w_desc.load([n_offset, k_offset])
|
||||
w_mask = n_mask[:, None] & k_mask[None, :]
|
||||
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
||||
accumulator += tl.dot(grad_y, w_tile)
|
||||
|
||||
accumulator += tl.dot(grad_y, w_tile)
|
||||
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
|
||||
0, BLOCK_SIZE_M
|
||||
)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
tbidx += NUM_SMS
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -342,29 +350,29 @@ def _kernel_mg_dw_tma(
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size <= 0:
|
||||
continue
|
||||
if m_size > 0:
|
||||
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
||||
rows_remaining = m_size - m_offset_local
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
||||
rows_remaining = m_size - m_offset_local
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
m_offset = (M_start + m_offset_local).to(tl.int32)
|
||||
|
||||
m_offset = (M_start + m_offset_local).to(tl.int32)
|
||||
x_block = x_desc.load([m_offset, k_offset])
|
||||
x_mask = row_mask[:, None] & k_mask[None, :]
|
||||
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
||||
|
||||
x_block = x_desc.load([m_offset, k_offset])
|
||||
x_mask = row_mask[:, None] & k_mask[None, :]
|
||||
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
||||
grad_block = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_block = tl.where(
|
||||
grad_mask, grad_block, tl.zeros_like(grad_block)
|
||||
)
|
||||
|
||||
grad_block = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_block = tl.where(grad_mask, grad_block, tl.zeros_like(grad_block))
|
||||
|
||||
contribution = tl.dot(
|
||||
grad_block.to(tl.float32).T,
|
||||
x_block.to(tl.float32),
|
||||
)
|
||||
accumulator += contribution
|
||||
contribution = tl.dot(
|
||||
grad_block.to(tl.float32).T,
|
||||
x_block.to(tl.float32),
|
||||
)
|
||||
accumulator += contribution
|
||||
|
||||
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
|
||||
row_store_mask = row_offsets < N
|
||||
|
||||
Reference in New Issue
Block a user