Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
ca379405c1 use narrow as a view on the student logits instead of slicing 2025-02-04 09:34:26 -05:00

View File

@@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()