handle token/logprob shifting

Author: Wing Lian
Date: 2024-12-30 11:21:19 -05:00
parent 69ed25e82c
commit 55b33cc44d
2 changed files with 20 additions and 9 deletions

View File

@@ -41,7 +41,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
         self._signature_columns += columns_to_add

     def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
+        self,
+        model,
+        inputs,
+        return_outputs=False,
+        num_items_in_batch=None,
+        shift_targets=False,
     ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -65,16 +70,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
         # FIXME: account for tokenizer.padding_side
         student_logits = outputs["logits"][:, :seq_len, :].contiguous()
-        shift_logits = student_logits[..., :-1, :].contiguous()
-        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
-        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
-        shift_target_mask = target_mask[..., 1:, :].contiguous()
+        if shift_targets:
+            shift_logits = student_logits[..., :-1, :].contiguous()
+            target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
+            target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
+            target_mask_for_loss = target_mask[..., 1:, :].contiguous()
+        else:
+            shift_logits = student_logits.contiguous()
+            target_logprobs_for_loss = target_logprobs.contiguous()
+            target_token_ids_for_loss = target_token_ids.contiguous()
+            target_mask_for_loss = target_mask.contiguous()

         loss_kd = topk_kd_loss(
             shift_logits,
-            shift_target_token_ids,
-            shift_target_logprobs,
-            shift_target_mask,
+            target_token_ids_for_loss,
+            target_logprobs_for_loss,
+            target_mask_for_loss,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )
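For orientation, the sketch below shows roughly how the new shift_targets flag interacts with a top-k KD loss: when the teacher logprobs/token ids are still aligned with the input positions, the student logits are trimmed by one step and the teacher tensors advanced by one step, so the student's prediction at position t is compared against the teacher's distribution for the token at position t + 1. The real topk_kd_loss lives elsewhere in the repo; the forward-KL-over-top-k body, the helper names, and the exact tensor shapes below are assumptions for illustration only.

import torch
import torch.nn.functional as F


def topk_kd_loss_sketch(
    student_logits,     # (batch, seq, vocab) student logits
    target_token_ids,   # (batch, seq, top_k) teacher's top-k token ids (int64)
    target_logprobs,    # (batch, seq, top_k) teacher log-probs for those ids
    target_mask,        # (batch, seq, top_k) 1 where a slot should count
    num_items_in_batch=None,
    kd_temperature=1.0,
):
    # Student log-probs gathered at the teacher's top-k token ids.
    student_logprobs = F.log_softmax(student_logits / kd_temperature, dim=-1)
    student_topk = torch.gather(student_logprobs, dim=-1, index=target_token_ids)

    # Zero out padded slots (which carry -inf log-probs) before exponentiating,
    # so masked positions contribute exactly 0 instead of NaN.
    safe_logprobs = target_logprobs.masked_fill(target_mask == 0, 0.0)
    teacher_probs = safe_logprobs.exp() * target_mask

    # Forward KL restricted to the teacher's top-k support (assumed loss form).
    kl = teacher_probs * (safe_logprobs - student_topk)

    denom = (
        num_items_in_batch
        if num_items_in_batch is not None
        else target_mask.sum().clamp(min=1)
    )
    return kl.sum() / denom


def kd_loss(student_logits, token_ids, logprobs, mask, shift_targets=False, **kw):
    if shift_targets:
        # Causal alignment: student position t predicts token t + 1, so drop the
        # student's last step and the teacher's first step before comparing.
        student_logits = student_logits[..., :-1, :].contiguous()
        logprobs = logprobs[..., 1:, :].contiguous()
        token_ids = token_ids[..., 1:, :].contiguous()
        mask = mask[..., 1:, :].contiguous()
    return topk_kd_loss_sketch(student_logits, token_ids, logprobs, mask, **kw)

With shift_targets=False the tensors are passed through untouched, which matches the new else branch: the dataset is expected to have already put the teacher targets in next-token position.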

View File

@@ -502,7 +502,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         # fill with -inf for padding_len tokens for top_k tokens
         # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
-        for _ in range(input_padding_len):
+        for _ in range(1, input_padding_len): # start at 1 since this is causal
             target_logprobs.append([-float("inf")] * top_k)
             target_token_ids.append(list(range(top_k)))
             target_mask.append([0] * top_k)
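One way to read the dataset-side change: range(1, input_padding_len) emits input_padding_len - 1 padding rows, one fewer than before, which would keep the teacher tensors aligned with a sequence that is effectively one position shorter after the causal shift. That interpretation is an assumption; the off-by-one itself is just the following (hypothetical numbers for illustration):

# Hypothetical values, only to show the effect of range(1, n).
input_padding_len, top_k = 4, 3
target_logprobs, target_token_ids, target_mask = [], [], []
for _ in range(1, input_padding_len):  # input_padding_len - 1 iterations
    target_logprobs.append([-float("inf")] * top_k)
    target_token_ids.append(list(range(top_k)))
    target_mask.append([0] * top_k)
assert len(target_logprobs) == input_padding_len - 1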