diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..e57731461 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,8 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton class AxolotlKDTrainer(AxolotlTrainer): @@ -85,7 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer): num_items_in_batch=num_items_in_batch, ) else: - loss_kd = topk_kd_loss( + loss_kd = topk_kd_loss_triton( shift_logits, target_token_ids_for_loss, target_logprobs_for_loss,