handle token/logprob shifting
This commit is contained in:
@@ -41,7 +41,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
shift_targets=False,
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
@@ -65,16 +70,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||
|
||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
||||
shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
|
||||
shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
|
||||
shift_target_mask = target_mask[..., 1:, :].contiguous()
|
||||
if shift_targets:
|
||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
else:
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs.contiguous()
|
||||
target_token_ids_for_loss = target_token_ids.contiguous()
|
||||
target_mask_for_loss = target_mask.contiguous()
|
||||
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
shift_target_token_ids,
|
||||
shift_target_logprobs,
|
||||
shift_target_mask,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
)
|
||||
|
||||
@@ -502,7 +502,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
for _ in range(input_padding_len):
|
||||
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
Reference in New Issue
Block a user