no where support

This commit is contained in:
Wing Lian
2024-12-21 13:21:54 -05:00
parent bc3326a808
commit 48ccf55752

View File

@@ -55,19 +55,20 @@ def fwd_kl_topk_kernel(
# load student logprobs
s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf"))
# load mask => bool
valid = tl.load(
m_row_ptr + k_offset * stride_mk, mask=mask, other=0
) # 0 or 1 => bool
valid_bool = valid.to(tl.int1)
# load mask => bool (0 or 1)
valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0)
valid_f32 = valid.to(tl.float32)
# teacher probs
t_p = tl.exp(t_lp)
# local_kl = p^T * (lp^T - lp^S) * valid
# local_kl = p^T * (lp^T - lp^S)
local_kl = t_p * (t_lp - s_lp)
# sum only over valid positions
kl_sum += tl.sum(local_kl, where=valid_bool)
# multiply by valid_f32 to ignore padded or invalid positions
local_kl *= valid_f32
# sum over the chunk
kl_sum += tl.sum(local_kl, where=mask)
# store rowwise result
tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)