From 04efcb102fc6328635ff1e96f97c6ece46f26b6f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 15 Jan 2025 01:07:48 -0500
Subject: [PATCH] don't shift student logits for kd

---
 src/axolotl/integrations/kd/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index 1aa1df452..9eac4cc1d 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -70,7 +70,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
         student_logits = outputs["logits"][:, :seq_len, :].contiguous()
 
         if shift_targets:
-            shift_logits = student_logits[..., :-1, :].contiguous()
+            # shift_logits = student_logits[..., :-1, :].contiguous()
+            shift_logits = student_logits.contiguous()
             target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
             target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
             target_mask_for_loss = target_mask[..., 1:, :].contiguous()