fix length of padding
This commit is contained in:
@@ -356,9 +356,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||||
|
|
||||||
# generate a hash over seq_input_ids and convert it to an int
|
|
||||||
# hash_input_ids: int = hash(tuple(seq_input_ids))
|
|
||||||
|
|
||||||
seq_len = len(seq_input_ids)
|
seq_len = len(seq_input_ids)
|
||||||
|
|
||||||
for i, _, label in zip(
|
for i, _, label in zip(
|
||||||
@@ -438,7 +435,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
list(range(self.kd_online_topk))
|
list(range(self.kd_online_topk))
|
||||||
)
|
)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
for i in range(len(current_target_logprobs) - seq_len):
|
for i in range(min(0, seq_len - len(current_target_logprobs))):
|
||||||
current_target_logprobs.append(
|
current_target_logprobs.append(
|
||||||
[-float("inf")] * self.kd_online_topk
|
[-float("inf")] * self.kd_online_topk
|
||||||
)
|
)
|
||||||
@@ -451,6 +448,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
||||||
ret_logprobs_data["target_mask"].append(current_target_mask)
|
ret_logprobs_data["target_mask"].append(current_target_mask)
|
||||||
|
|
||||||
|
# TODO save and load targets to disk for caching for next epoch
|
||||||
|
# generate a hash over seq_input_ids and convert it to an int
|
||||||
|
# hash_input_ids: int = hash(tuple(seq_input_ids))
|
||||||
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
|
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
|
||||||
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
|
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
|
||||||
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
|
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user