handle token/logprob shifting

This commit is contained in:
Wing Lian
2024-12-30 11:21:19 -05:00
parent 69ed25e82c
commit 55b33cc44d
2 changed files with 20 additions and 9 deletions

View File

@@ -41,7 +41,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
self._signature_columns += columns_to_add self._signature_columns += columns_to_add
def compute_loss( def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
shift_targets=False,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -65,16 +70,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
# FIXME: account for tokenizer.padding_side # FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous() student_logits = outputs["logits"][:, :seq_len, :].contiguous()
shift_logits = student_logits[..., :-1, :].contiguous() if shift_targets:
shift_target_logprobs = target_logprobs[..., 1:, :].contiguous() shift_logits = student_logits[..., :-1, :].contiguous()
shift_target_token_ids = target_token_ids[..., 1:, :].contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
shift_target_mask = target_mask[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss( loss_kd = topk_kd_loss(
shift_logits, shift_logits,
shift_target_token_ids, target_token_ids_for_loss,
shift_target_logprobs, target_logprobs_for_loss,
shift_target_mask, target_mask_for_loss,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature, kd_temperature=self.args.kd_temperature,
) )

View File

@@ -502,7 +502,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(input_padding_len): for _ in range(1, input_padding_len): # start at 1 since this is causal
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)