Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
3 changed files with 33 additions and 11 deletions

View File

@@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
)
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
@@ -62,16 +66,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
# for _ in range(target_seq_len):
# # TODO also check against sample["labels"]
# target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
@@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs

View File

@@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs,
return_outputs=False,
num_items_in_batch=None,
shift_targets=False,
shift_targets=True,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
labels = sample["labels"]
if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop
return True # or False
return True
# Check if single example or batch
# If first element is an int, we assume a single example