use narrow as a view on the student logits instead of slicing
This commit is contained in:
@@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
outputs = model(**inputs)
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
- student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
+ shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous()
|
||||
|
||||
- shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
Reference in New Issue
Block a user