no torch.tensor
This commit is contained in:
@@ -107,9 +107,9 @@ def kd_forward_kernel(
|
||||
exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
|
||||
exp_sum += exp_val
|
||||
# safe check
|
||||
if exp_sum == 0.0:
|
||||
exp_sum = 1e-8
|
||||
logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))
|
||||
epsilon = 1e-8 # Small constant to prevent log(0)
|
||||
exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum)
|
||||
logsumexp_val = logsumexp_val + tl.log(exp_sum)
|
||||
|
||||
# compute partial kl
|
||||
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
|
||||
|
||||
Reference in New Issue
Block a user