fix length of padding
@@ -356,9 +356,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         # basic check that the logprob data len matches the input len, so no need to handle padding
         assert len(seq_input_ids) == len(input_top_logprobs)
 
-        # generate a hash over seq_input_ids and convert it to an int
-        # hash_input_ids: int = hash(tuple(seq_input_ids))
-
         seq_len = len(seq_input_ids)
 
         for i, _, label in zip(
@@ -438,7 +435,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                     list(range(self.kd_online_topk))
                 )
                 current_target_mask.append([0] * self.kd_online_topk)
-            for i in range(len(current_target_logprobs) - seq_len):
+            for i in range(min(0, seq_len - len(current_target_logprobs))):
                 current_target_logprobs.append(
                     [-float("inf")] * self.kd_online_topk
                 )
@@ -451,6 +448,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
         ret_logprobs_data["target_mask"].append(current_target_mask)
 
+        # TODO save and load targets to disk for caching for next epoch
+        # generate a hash over seq_input_ids and convert it to an int
+        # hash_input_ids: int = hash(tuple(seq_input_ids))
         # with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
         #     pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
         # with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
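
For context on the second hunk: the loop tops up current_target_logprobs with rows of -inf so there is one top-k row per input token. Below is a minimal standalone sketch of that padding shape, assuming rows of kd_online_topk log-probabilities; the helper name pad_logprob_rows and the max(0, ...) clamp are illustrative additions, not the commit's code (the committed loop uses min(0, ...), which always yields an empty range, presumably because lengths are expected to already match at that point).

from typing import List

NEG_INF = -float("inf")

def pad_logprob_rows(rows: List[List[float]], seq_len: int, topk: int) -> List[List[float]]:
    # Illustrative helper, not from the commit: append one row of -inf per
    # missing token position so padded positions carry no probability mass.
    # max(0, ...) clamps the count, so an already-full list is left alone.
    pad_count = max(0, seq_len - len(rows))
    return rows + [[NEG_INF] * topk for _ in range(pad_count)]

# Two top-3 rows padded out to a 4-token sequence.
rows = [[-0.1, -2.3, -4.0], [-0.5, -1.2, -3.3]]
padded = pad_logprob_rows(rows, seq_len=4, topk=3)
assert len(padded) == 4 and padded[2] == [NEG_INF] * 3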
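
The commented-out block restored in the third hunk sketches the TODO's disk cache: key each sequence by an int hash of its token ids and round-trip the per-sequence target logprobs through parquet. A rough sketch of that round trip, assuming pandas with a parquet engine (e.g. pyarrow) is installed; the helper names and the cache-miss convention are assumptions, and only the /tmp path pattern and the hash(tuple(seq_input_ids)) keying come from the commit:

import os
from typing import List, Optional

import pandas as pd

CACHE_DIR = "/tmp"  # mirrors the commented-out paths in the diff

def _cache_path(seq_input_ids: List[int]) -> str:
    # Same keying idea as the commit's comment: an int hash over the token ids.
    hash_input_ids: int = hash(tuple(seq_input_ids))
    return os.path.join(CACHE_DIR, f"target_logprobs_{hash_input_ids}.parquet")

def save_target_logprobs(seq_input_ids: List[int], target_logprobs: List[List[float]]) -> None:
    # One row per token, one column per top-k slot; -inf survives the round trip.
    pd.DataFrame(target_logprobs).to_parquet(_cache_path(seq_input_ids), index=False)

def load_target_logprobs(seq_input_ids: List[int]) -> Optional[List[List[float]]]:
    path = _cache_path(seq_input_ids)
    if not os.path.exists(path):
        return None  # cache miss: caller falls back to querying the teacher
    return pd.read_parquet(path).values.tolist()

CPython's hash over a tuple of ints is deterministic, so the same sequence maps to the same file across epochs within one interpreter; a content hash (e.g. hashlib.sha256 over the raw id bytes) would be the sturdier key if the cache ever has to outlive a single Python version.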