no torch.tensor

This commit is contained in:
Wing Lian
2024-12-21 14:00:01 -05:00
parent dc90c93894
commit 081928e55b

View File

@@ -107,9 +107,9 @@ def kd_forward_kernel(
exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
exp_sum += exp_val
# safe check
if exp_sum == 0.0:
exp_sum = 1e-8
logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))
epsilon = 1e-8 # Small constant to prevent log(0)
exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum)
logsumexp_val = logsumexp_val + tl.log(exp_sum)
# compute partial kl
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k)