From 9eb53f5c9ed5218542220691ddc93e362e557dfb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 May 2025 22:28:26 -0400 Subject: [PATCH] fix length of padding --- src/axolotl/integrations/kd/collator_online_teacher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py index 521d9af4a..30c98bb4b 100644 --- a/src/axolotl/integrations/kd/collator_online_teacher.py +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -356,9 +356,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # basic check that the logprob data len matches the input len, so no need to handle padding assert len(seq_input_ids) == len(input_top_logprobs) - # generate a hash over seq_input_ids and convert it to an int - # hash_input_ids: int = hash(tuple(seq_input_ids)) - seq_len = len(seq_input_ids) for i, _, label in zip( @@ -438,7 +435,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): list(range(self.kd_online_topk)) ) current_target_mask.append([0] * self.kd_online_topk) - for i in range(len(current_target_logprobs) - seq_len): + for i in range(max(0, seq_len - len(current_target_logprobs))): current_target_logprobs.append( [-float("inf")] * self.kd_online_topk ) @@ -451,6 +448,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): ret_logprobs_data["target_logprobs"].append(current_target_logprobs) ret_logprobs_data["target_mask"].append(current_target_mask) + # TODO save and load targets to disk for caching for next epoch + # generate a hash over seq_input_ids and convert it to an int + # hash_input_ids: int = hash(tuple(seq_input_ids)) # with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f: # pd.DataFrame(current_target_logprobs).to_parquet(f, index=False) # with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f: