use narrow as a view on the student logits instead of slicing

This commit is contained in:
Wing Lian
2025-02-04 09:34:26 -05:00
parent 158330ab60
commit ca379405c1

View File

@@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()