don't shift student logits for kd
This commit is contained in:
@@ -70,7 +70,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||||
|
|
||||||
if shift_targets:
|
if shift_targets:
|
||||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
# shift_logits = student_logits[..., :-1, :].contiguous()
|
||||||
|
shift_logits = student_logits.contiguous()
|
||||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user