use triton for kd-loss in trainer
This commit is contained in:
@@ -18,8 +18,8 @@ KD trainer
 
 from axolotl.core.trainers.base import AxolotlTrainer
 
-from .topk_logprob.forward_kl import loss as topk_kd_loss
 from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
+from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
 
 
 class AxolotlKDTrainer(AxolotlTrainer):
@@ -85,7 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
                 num_items_in_batch=num_items_in_batch,
             )
         else:
-            loss_kd = topk_kd_loss(
+            loss_kd = topk_kd_loss_triton(
                 shift_logits,
                 target_token_ids_for_loss,
                 target_logprobs_for_loss,
Reference in New Issue
Block a user