no where support
This commit is contained in:
@@ -55,19 +55,20 @@ def fwd_kl_topk_kernel(
|
||||
# load student logprobs
|
||||
s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf"))
|
||||
|
||||
# load mask => bool
|
||||
valid = tl.load(
|
||||
m_row_ptr + k_offset * stride_mk, mask=mask, other=0
|
||||
) # 0 or 1 => bool
|
||||
valid_bool = valid.to(tl.int1)
|
||||
# load mask => bool (0 or 1)
|
||||
valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0)
|
||||
valid_f32 = valid.to(tl.float32)
|
||||
|
||||
# teacher probs
|
||||
t_p = tl.exp(t_lp)
|
||||
|
||||
# local_kl = p^T * (lp^T - lp^S) * valid
|
||||
# local_kl = p^T * (lp^T - lp^S)
|
||||
local_kl = t_p * (t_lp - s_lp)
|
||||
# sum only over valid positions
|
||||
kl_sum += tl.sum(local_kl, where=valid_bool)
|
||||
# multiply by valid_f32 to ignore padded or invalid positions
|
||||
local_kl *= valid_f32
|
||||
|
||||
# sum over the chunk
|
||||
kl_sum += tl.sum(local_kl, where=mask)
|
||||
|
||||
# store rowwise result
|
||||
tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)
|
||||
|
||||
Reference in New Issue
Block a user