improve logprob masking and shift in trainer

This commit is contained in:
Wing Lian
2025-01-14 20:10:05 -05:00
parent 7232cbdeab
commit 510cf45317
2 changed files with 28 additions and 6 deletions

View File

@@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
) )
def transform_logprobs(self, sample): def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
@@ -62,9 +66,20 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0:
# logprobs is longer than input_seq_len (the tokenized input),
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
target_seq_len = input_seq_len
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
@@ -73,6 +88,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# TODO also check against sample["labels"] # TODO also check against sample["labels"]
target_mask.append([1] * top_k) target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
position_logprobs = [] position_logprobs = []
@@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token if shift == 1:
target_logprobs.append([-float("inf")] * top_k) # since we started at index 1 for causal, we need one more padding token
target_token_ids.append(list(range(top_k))) target_logprobs.append([-float("inf")] * top_k)
target_mask.append([0] * top_k) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs

View File

@@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=False, shift_targets=True,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.