no torch.exp inside triton kernel

This commit is contained in:
Wing Lian
2024-12-21 13:52:31 -05:00
parent 119d586cf4
commit 18a46c338a

View File

@@ -114,7 +114,8 @@ def kd_forward_kernel(
exp_sum = 0.0
for k in range(K):
if valid_k[k] != 0: # if valid
exp_sum += float(torch.exp(student_logits_k[k] - logsumexp_val))
exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
exp_sum += exp_val
# safe check
if exp_sum == 0.0:
exp_sum = 1e-8
@@ -125,7 +126,7 @@ def kd_forward_kernel(
# teacher_probs_k = exp(teacher_logprobs_k)
for k in range(K):
if valid_k[k] != 0: # only valid tokens
teacher_prob = float(torch.exp(teacher_logprobs_k[k]))
teacher_prob = tl.exp(teacher_logprobs_k[k])
student_logprob = student_logits_k[k] - logsumexp_val
kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
kd_reg += kd_val