diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 1aa1df452..9eac4cc1d 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -70,7 +70,8 @@ class AxolotlKDTrainer(AxolotlTrainer): student_logits = outputs["logits"][:, :seq_len, :].contiguous() if shift_targets: - shift_logits = student_logits[..., :-1, :].contiguous() + # shift_logits = student_logits[..., :-1, :].contiguous() + shift_logits = student_logits.contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()