no torch.exp inside triton kernel
This commit is contained in:
@@ -114,7 +114,8 @@ def kd_forward_kernel(
|
||||
exp_sum = 0.0
|
||||
for k in range(K):
|
||||
if valid_k[k] != 0: # if valid
|
||||
exp_sum += float(torch.exp(student_logits_k[k] - logsumexp_val))
|
||||
exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
|
||||
exp_sum += exp_val
|
||||
# safe check
|
||||
if exp_sum == 0.0:
|
||||
exp_sum = 1e-8
|
||||
@@ -125,7 +126,7 @@ def kd_forward_kernel(
|
||||
# teacher_probs_k = exp(teacher_logprobs_k)
|
||||
for k in range(K):
|
||||
if valid_k[k] != 0: # only valid tokens
|
||||
teacher_prob = float(torch.exp(teacher_logprobs_k[k]))
|
||||
teacher_prob = tl.exp(teacher_logprobs_k[k])
|
||||
student_logprob = student_logits_k[k] - logsumexp_val
|
||||
kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
|
||||
kd_reg += kd_val
|
||||
|
||||
Reference in New Issue
Block a user