Don't shift student logits for KD

This commit is contained in:
Wing Lian
2025-01-15 01:07:48 -05:00
parent 483defb9ae
commit 04efcb102f

View File

@@ -70,7 +70,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets:
- shift_logits = student_logits[..., :-1, :].contiguous()
+ # shift_logits = student_logits[..., :-1, :].contiguous()
+ shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()