handle token/logprob shifting
This commit is contained in:
@@ -41,7 +41,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
self._signature_columns += columns_to_add
|
self._signature_columns += columns_to_add
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False,
|
||||||
|
num_items_in_batch=None,
|
||||||
|
shift_targets=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||||
@@ -65,16 +70,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
# FIXME: account for tokenizer.padding_side
|
# FIXME: account for tokenizer.padding_side
|
||||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||||
|
|
||||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
if shift_targets:
|
||||||
shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
|
shift_logits = student_logits[..., :-1, :].contiguous()
|
||||||
shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
|
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||||
shift_target_mask = target_mask[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = student_logits.contiguous()
|
||||||
|
target_logprobs_for_loss = target_logprobs.contiguous()
|
||||||
|
target_token_ids_for_loss = target_token_ids.contiguous()
|
||||||
|
target_mask_for_loss = target_mask.contiguous()
|
||||||
|
|
||||||
loss_kd = topk_kd_loss(
|
loss_kd = topk_kd_loss(
|
||||||
shift_logits,
|
shift_logits,
|
||||||
shift_target_token_ids,
|
target_token_ids_for_loss,
|
||||||
shift_target_logprobs,
|
target_logprobs_for_loss,
|
||||||
shift_target_mask,
|
target_mask_for_loss,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
kd_temperature=self.args.kd_temperature,
|
kd_temperature=self.args.kd_temperature,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -502,7 +502,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
|
|
||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
for _ in range(input_padding_len):
|
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|||||||
Reference in New Issue
Block a user