diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index 6e0a87735..2cd19aa64 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -107,9 +107,9 @@ def kd_forward_kernel( exp_val = tl.exp(student_logits_k[k] - logsumexp_val) exp_sum += exp_val # safe check - if exp_sum == 0.0: - exp_sum = 1e-8 - logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum)) + epsilon = 1e-8 # Small constant to prevent log(0) + exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum) + logsumexp_val = logsumexp_val + tl.log(exp_sum) # compute partial kl # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)