Use tl.exp instead of torch.exp inside the Triton kernel
This commit is contained in:
@@ -114,7 +114,8 @@ def kd_forward_kernel(
|
|||||||
exp_sum = 0.0
|
exp_sum = 0.0
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
if valid_k[k] != 0: # if valid
|
if valid_k[k] != 0: # if valid
|
||||||
exp_sum += float(torch.exp(student_logits_k[k] - logsumexp_val))
|
exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
|
||||||
|
exp_sum += exp_val
|
||||||
# safe check
|
# safe check
|
||||||
if exp_sum == 0.0:
|
if exp_sum == 0.0:
|
||||||
exp_sum = 1e-8
|
exp_sum = 1e-8
|
||||||
@@ -125,7 +126,7 @@ def kd_forward_kernel(
|
|||||||
# teacher_probs_k = exp(teacher_logprobs_k)
|
# teacher_probs_k = exp(teacher_logprobs_k)
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
if valid_k[k] != 0: # only valid tokens
|
if valid_k[k] != 0: # only valid tokens
|
||||||
teacher_prob = float(torch.exp(teacher_logprobs_k[k]))
|
teacher_prob = tl.exp(teacher_logprobs_k[k])
|
||||||
student_logprob = student_logits_k[k] - logsumexp_val
|
student_logprob = student_logits_k[k] - logsumexp_val
|
||||||
kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
|
kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
|
||||||
kd_reg += kd_val
|
kd_reg += kd_val
|
||||||
|
|||||||
Reference in New Issue
Block a user