From e376e00386c8cddf342a7406f34880a1ec380dc0 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 21 Dec 2024 16:47:11 -0500
Subject: [PATCH] don't use triton for now

---
 src/axolotl/core/trainers/kd.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index f90e7d02e..893c529d8 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -100,18 +100,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
         # Slice or gather student logits to match teacher seq len
         # e.g.:
         teacher_seq_len = target_token_ids.shape[1]
-        student_logits_for_kd = student_logits[:, :teacher_seq_len, :] # [B, seq_len, vocab_size]
+        student_logits_for_kd = student_logits[
+            :, :teacher_seq_len, :
+        ]  # [B, seq_len, vocab_size]
 
         # GATHER top-K from student
         student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids # same shape [B, seq_len, K]
+            student_logits_for_kd,
+            dim=-1,
+            index=target_token_ids,  # same shape [B, seq_len, K]
         )
 
         # Now call the Triton-based KD loss
         kd_sum = kd_loss_triton(
             student_logits_topk,
-            target_logprobs, # teacher logprobs [B, seq_len, K]
-            target_mask, # mask [B, seq_len, K]
+            target_logprobs,  # teacher logprobs [B, seq_len, K]
+            target_mask,  # mask [B, seq_len, K]
         )
 
         # Normalize however you want
@@ -140,9 +144,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
 
         Subclass and override for custom behavior.
         """
-        return self.compute_loss_w_triton(
-            model, inputs, return_outputs, num_items_in_batch
-        )
+        # return self.compute_loss_w_triton(
+        #     model, inputs, return_outputs, num_items_in_batch
+        # )
 
         target_logprobs = inputs.pop("target_logprobs")
         target_token_ids = inputs.pop("target_token_ids")
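
Note: with compute_loss_w_triton commented out, compute_loss falls through
to the eager path that consumes the popped target_logprobs /
target_token_ids. For reference, below is a minimal pure-PyTorch sketch of
the top-K KD computation that the bypassed kd_loss_triton call performs.
The function name kd_loss_topk, the renormalization of the student over
only the K gathered logits, and the final mask-count normalization are
illustrative assumptions, not the exact axolotl implementation.

    import torch
    import torch.nn.functional as F

    def kd_loss_topk(
        student_logits_topk: torch.Tensor,  # [B, seq_len, K] student logits
                                            # gathered at teacher top-K ids
        target_logprobs: torch.Tensor,      # [B, seq_len, K] teacher logprobs
        target_mask: torch.Tensor,          # [B, seq_len, K] 1 = keep entry
    ) -> torch.Tensor:
        # Renormalize the student distribution over the K gathered entries.
        # This is an approximation: the true student distribution spans the
        # full vocab, but only the teacher's top-K logits are available here.
        student_logprobs = F.log_softmax(student_logits_topk, dim=-1)
        teacher_probs = target_logprobs.exp()
        # Forward KL per entry: p_T * (log p_T - log p_S)
        kl = teacher_probs * (target_logprobs - student_logprobs)
        kl = kl * target_mask
        # One possible normalization: mean over all unmasked entries
        return kl.sum() / target_mask.sum().clamp(min=1)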