diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py index 503006f15..521d9af4a 100644 --- a/src/axolotl/integrations/kd/collator_online_teacher.py +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -27,6 +27,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): kd_online_topk: Optional[int] = None, kd_temperature: Optional[float] = 1.0, kd_online_server: Optional[str] = "vllm", + kd_online_timeout: Optional[int] = 120, **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -45,6 +46,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): self.kd_temperature = kd_temperature self.kd_online_server = kd_online_server self.http_session = requests.Session() + self.kd_online_timeout = kd_online_timeout def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: """ @@ -145,7 +147,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } try: - response = self.http_session.post(api_endpoint, json=payload, timeout=60) + response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout) response.raise_for_status() api_data: list[dict] = response.json() @@ -262,12 +264,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): exc_info=True, ) raise e - # Return initialized empty data - # return { - # "target_token_ids": [], - # "target_logprobs": [], - # "target_mask": [], - # } return ret_logprobs_data @@ -303,7 +299,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } try: - response = self.http_session.post(api_endpoint, json=payload, timeout=60) + response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout) response.raise_for_status() api_data: dict = response.json() choices: list[dict] = api_data["choices"] @@ -361,22 +357,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): assert len(seq_input_ids) == len(input_top_logprobs) # generate a hash over seq_input_ids and convert it to an int - hash_input_ids: int = hash(tuple(seq_input_ids)) + # hash_input_ids: int = hash(tuple(seq_input_ids)) + + seq_len = len(seq_input_ids) for i, _, label in zip( - range(len(seq_input_ids)), seq_input_ids, seq_labels + range(seq_len), seq_input_ids, seq_labels ): if i < len(input_top_logprobs) and input_top_logprobs[i] is None: # this is always the case for the first token. # there is never logprob data for the first token since that's a true input - # so we replace the None value with padding data - current_target_logprobs.append( - [-float("inf")] * self.kd_online_topk - ) - current_target_token_ids.append( - list(range(self.kd_online_topk)) - ) - current_target_mask.append([0] * self.kd_online_topk) + continue elif ( i < len(input_top_logprobs) and input_top_logprobs[i] is not None @@ -447,15 +438,23 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): list(range(self.kd_online_topk)) ) current_target_mask.append([0] * self.kd_online_topk) + for i in range(len(current_target_logprobs) - seq_len): + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) + current_target_mask.append([0] * self.kd_online_topk) ret_logprobs_data["target_token_ids"].append(current_target_token_ids) ret_logprobs_data["target_logprobs"].append(current_target_logprobs) ret_logprobs_data["target_mask"].append(current_target_mask) - with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f: - pd.DataFrame(current_target_logprobs).to_parquet(f, index=False) - with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f: - pd.DataFrame(current_target_token_ids).to_parquet(f, index=False) + # with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f: + # pd.DataFrame(current_target_logprobs).to_parquet(f, index=False) + # with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f: + # pd.DataFrame(current_target_token_ids).to_parquet(f, index=False) except requests.exceptions.RequestException as e: LOG.error(f"Error fetching logprobs from online teacher: {e}")