From 0da2b7c7ccc00f1d25ddc0f3ef8899a2a9fc41a6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Dec 2024 13:52:31 -0500 Subject: [PATCH] no torch.exp inside triton kernel --- src/axolotl/integrations/kd/kernels/kd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index 058cfe334..36aacfef2 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -114,7 +114,8 @@ def kd_forward_kernel( exp_sum = 0.0 for k in range(K): if valid_k[k] != 0: # if valid - exp_sum += float(torch.exp(student_logits_k[k] - logsumexp_val)) + exp_val = tl.exp(student_logits_k[k] - logsumexp_val) + exp_sum += exp_val # safe check if exp_sum == 0.0: exp_sum = 1e-8 @@ -125,7 +126,7 @@ def kd_forward_kernel( # teacher_probs_k = exp(teacher_logprobs_k) for k in range(K): if valid_k[k] != 0: # only valid tokens - teacher_prob = float(torch.exp(teacher_logprobs_k[k])) + teacher_prob = tl.exp(teacher_logprobs_k[k]) student_logprob = student_logits_k[k] - logsumexp_val kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob) kd_reg += kd_val