Fix padding-length calculation in OnlineTeacherCollator

This commit is contained in:
Wing Lian
2025-05-26 22:28:26 -04:00
parent 225b420dc5
commit 9eb53f5c9e

View File

@@ -356,9 +356,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
# generate a hash over seq_input_ids and convert it to an int
# hash_input_ids: int = hash(tuple(seq_input_ids))
seq_len = len(seq_input_ids)
for i, _, label in zip(
@@ -438,7 +435,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(len(current_target_logprobs) - seq_len):
for i in range(min(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
@@ -451,6 +448,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask)
# TODO save and load targets to disk for caching for next epoch
# generate a hash over seq_input_ids and convert it to an int
# hash_input_ids: int = hash(tuple(seq_input_ids))
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f: