Use the Triton top-k KD loss (forward_kl_triton) in the KD trainer

This commit is contained in:
Wing Lian
2025-02-24 22:58:35 -05:00
parent 75e1480c10
commit e82268e580

View File

@@ -18,8 +18,8 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
class AxolotlKDTrainer(AxolotlTrainer):
@@ -85,7 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
loss_kd = topk_kd_loss_triton(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,