From ca379405c1c413f7b170afa8b9f40f7f11271aac Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 4 Feb 2025 09:34:26 -0500 Subject: [PATCH] use narrow as a view on the student logits instead of slicing --- src/axolotl/integrations/kd/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..6a9989e7e 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer): outputs = model(**inputs) # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() + shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous() - shift_logits = student_logits.contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()