From 081928e55bfd7bafaf7809a305f3f879baa13513 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Dec 2024 14:00:01 -0500 Subject: [PATCH] no torch.tensor --- src/axolotl/integrations/kd/kernels/kd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index 6e0a87735..2cd19aa64 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -107,9 +107,9 @@ def kd_forward_kernel( exp_val = tl.exp(student_logits_k[k] - logsumexp_val) exp_sum += exp_val # safe check - if exp_sum == 0.0: - exp_sum = 1e-8 - logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum)) + epsilon = 1e-8 # Small constant to prevent log(0) + exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum) + logsumexp_val = logsumexp_val + tl.log(exp_sum) # compute partial kl # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)