improve logprob masking and shift in trainer
This commit is contained in:
@@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def transform_logprobs(self, sample):
|
def transform_logprobs(self, sample):
|
||||||
|
"""
|
||||||
|
Transform logprobs to target format for KD training
|
||||||
|
"""
|
||||||
|
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
@@ -62,9 +66,20 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
|
|
||||||
|
if input_padding_len < 0:
|
||||||
|
# logprobs is longer than target_seq_len,
|
||||||
|
# so we need to slice from the left/beginning of logprobs
|
||||||
|
logprobs = logprobs[:-input_seq_len]
|
||||||
|
input_padding_len = 0
|
||||||
|
target_seq_len = input_seq_len
|
||||||
|
|
||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
|
||||||
|
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||||
|
# otherwise, we need to shift in the trainer
|
||||||
|
shift = 0
|
||||||
|
for _ in range(shift, input_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
@@ -73,6 +88,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# TODO also check against sample["labels"]
|
# TODO also check against sample["labels"]
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
|
for position in range(input_padding_len, input_seq_len):
|
||||||
|
if sample["labels"][position] == -100:
|
||||||
|
target_mask.append([0] * top_k)
|
||||||
|
else:
|
||||||
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for _, token_pos_logprobs in enumerate(logprobs):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
position_logprobs = []
|
position_logprobs = []
|
||||||
@@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_logprobs.append(position_logprobs_scaled)
|
target_logprobs.append(position_logprobs_scaled)
|
||||||
target_token_ids.append(position_token_ids)
|
target_token_ids.append(position_token_ids)
|
||||||
|
|
||||||
# since we started at index 1 for causal, we need one more padding token
|
if shift == 1:
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
# since we started at index 1 for causal, we need one more padding token
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_mask.append([0] * top_k)
|
target_token_ids.append(list(range(top_k)))
|
||||||
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
inputs,
|
inputs,
|
||||||
return_outputs=False,
|
return_outputs=False,
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None,
|
||||||
shift_targets=False,
|
shift_targets=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||||
|
|||||||
Reference in New Issue
Block a user