diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..6a9989e7e 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer): outputs = model(**inputs) # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() + shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous() - shift_logits = student_logits.contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()