no where support
This commit is contained in:
@@ -55,19 +55,20 @@ def fwd_kl_topk_kernel(
|
|||||||
# load student logprobs
|
# load student logprobs
|
||||||
s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf"))
|
s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf"))
|
||||||
|
|
||||||
# load mask => bool
|
# load mask => bool (0 or 1)
|
||||||
valid = tl.load(
|
valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0)
|
||||||
m_row_ptr + k_offset * stride_mk, mask=mask, other=0
|
valid_f32 = valid.to(tl.float32)
|
||||||
) # 0 or 1 => bool
|
|
||||||
valid_bool = valid.to(tl.int1)
|
|
||||||
|
|
||||||
# teacher probs
|
# teacher probs
|
||||||
t_p = tl.exp(t_lp)
|
t_p = tl.exp(t_lp)
|
||||||
|
|
||||||
# local_kl = p^T * (lp^T - lp^S) * valid
|
# local_kl = p^T * (lp^T - lp^S)
|
||||||
local_kl = t_p * (t_lp - s_lp)
|
local_kl = t_p * (t_lp - s_lp)
|
||||||
# sum only over valid positions
|
# multiply by valid_f32 to ignore padded or invalid positions
|
||||||
kl_sum += tl.sum(local_kl, where=valid_bool)
|
local_kl *= valid_f32
|
||||||
|
|
||||||
|
# sum over the chunk
|
||||||
|
kl_sum += tl.sum(local_kl, where=mask)
|
||||||
|
|
||||||
# store rowwise result
|
# store rowwise result
|
||||||
tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)
|
tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)
|
||||||
|
|||||||
Reference in New Issue
Block a user