From e82268e580ffdc5efa8269692127654a2f91bc21 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 24 Feb 2025 22:58:35 -0500 Subject: [PATCH] use triton for kd-loss in trainer --- src/axolotl/integrations/kd/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..e57731461 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,8 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton class AxolotlKDTrainer(AxolotlTrainer): @@ -85,7 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer): num_items_in_batch=num_items_in_batch, ) else: - loss_kd = topk_kd_loss( + loss_kd = topk_kd_loss_triton( shift_logits, target_token_ids_for_loss, target_logprobs_for_loss,