more fixes

This commit is contained in:
Wing Lian
2025-01-14 21:37:10 -05:00
parent 510cf45317
commit 35a84f2cb8
2 changed files with 6 additions and 6 deletions

View File

@@ -71,7 +71,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
target_seq_len = input_seq_len
# target_seq_len = input_seq_len
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
@@ -84,9 +84,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
# for _ in range(target_seq_len):
# # TODO also check against sample["labels"]
# target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
labels = sample["labels"]
if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop
return True # or False
return True
# Check if single example or batch
# If first element is an int, we assume a single example