diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index c34d025ac..5d2ff19cf 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -55,19 +55,20 @@ def fwd_kl_topk_kernel( # load student logprobs s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf")) - # load mask => bool - valid = tl.load( - m_row_ptr + k_offset * stride_mk, mask=mask, other=0 - ) # 0 or 1 => bool - valid_bool = valid.to(tl.int1) + # load mask => bool (0 or 1) + valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0) + valid_f32 = valid.to(tl.float32) # teacher probs t_p = tl.exp(t_lp) - # local_kl = p^T * (lp^T - lp^S) * valid + # local_kl = p^T * (lp^T - lp^S) local_kl = t_p * (t_lp - s_lp) - # sum only over valid positions - kl_sum += tl.sum(local_kl, where=valid_bool) + # multiply by valid_f32 to ignore padded or invalid positions + local_kl *= valid_f32 + + # sum over the chunk + kl_sum += tl.sum(local_kl, where=mask) # store rowwise result tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)