use narrow as a view on the student logits instead of slicing
This commit is contained in:
@@ -67,9 +67,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# FIXME: account for tokenizer.padding_side
|
# FIXME: account for tokenizer.padding_side
|
||||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
shift_logits = outputs["logits"].narrow(1, 0, seq_len - 1).contiguous()
|
||||||
|
|
||||||
shift_logits = student_logits.contiguous()
|
|
||||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user