diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index 058cfe334..36aacfef2 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -114,7 +114,8 @@ def kd_forward_kernel( exp_sum = 0.0 for k in range(K): if valid_k[k] != 0: # if valid - exp_sum += float(torch.exp(student_logits_k[k] - logsumexp_val)) + exp_val = tl.exp(student_logits_k[k] - logsumexp_val) + exp_sum += exp_val # safe check if exp_sum == 0.0: exp_sum = 1e-8 @@ -125,7 +126,7 @@ def kd_forward_kernel( # teacher_probs_k = exp(teacher_logprobs_k) for k in range(K): if valid_k[k] != 0: # only valid tokens - teacher_prob = float(torch.exp(teacher_logprobs_k[k])) + teacher_prob = tl.exp(teacher_logprobs_k[k]) student_logprob = student_logits_k[k] - logsumexp_val kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob) kd_reg += kd_val