Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
3 changed files with 33 additions and 11 deletions

View File

@@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
) )
def transform_logprobs(self, sample): def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
@@ -62,16 +66,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0:
# logprobs is longer than input_seq_len (more logprob positions than input tokens),
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for _ in range(target_seq_len): # for _ in range(target_seq_len):
# TODO also check against sample["labels"] # # TODO also check against sample["labels"]
target_mask.append([1] * top_k) # target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
@@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token if shift == 1:
target_logprobs.append([-float("inf")] * top_k) # since we started at index 1 for causal, we need one more padding token
target_token_ids.append(list(range(top_k))) target_logprobs.append([-float("inf")] * top_k)
target_mask.append([0] * top_k) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs

View File

@@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=False, shift_targets=True,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences" drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
batched=True,
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
""" """
labels = sample["labels"] labels = sample["labels"]
if not labels: if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop return True
return True # or False
# Check if single example or batch # Check if single example or batch
# If first element is an int, we assume a single example # If first element is an int, we assume a single example