diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index c6ea9602f..e648bcd25 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -55,6 +55,7 @@ class KDPlugin(BasePlugin): "kd_online_topk": cfg.kd_online_topk, "kd_temperature": cfg.kd_temperature, "kd_online_server": cfg.kd_online_server, + "kd_online_timeout": cfg.kd_online_timeout, } if use_batch_sampler_collator: diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index 2029f6509..5c97e7bdd 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -41,3 +41,4 @@ class KDArgs(BaseModel): kd_online_server_base_url: str | None = None kd_online_topk: int | None = None kd_online_server: InferenceServerType | None = "vllm" + kd_online_timeout: int | None = 120 diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py index 30c98bb4b..39f6d16fb 100644 --- a/src/axolotl/integrations/kd/collator_online_teacher.py +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -1,6 +1,7 @@ """ Packed data loader for online teacher training supporting vllm and sglang. """ + import logging from typing import Any, Dict, List, Optional @@ -18,6 +19,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): """ Collator for online teacher training. """ + DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 def __init__( @@ -147,7 +149,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } try: - response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout) + response = self.http_session.post( + api_endpoint, json=payload, timeout=self.kd_online_timeout + ) response.raise_for_status() api_data: list[dict] = response.json() @@ -299,7 +303,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } try: - response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout) + response = self.http_session.post( + api_endpoint, json=payload, timeout=self.kd_online_timeout + ) response.raise_for_status() api_data: dict = response.json() choices: list[dict] = api_data["choices"] @@ -358,9 +364,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): seq_len = len(seq_input_ids) - for i, _, label in zip( - range(seq_len), seq_input_ids, seq_labels - ): + for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels): if i < len(input_top_logprobs) and input_top_logprobs[i] is None: # this is always the case for the first token. # there is never logprob data for the first token since that's a true input @@ -435,13 +439,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): list(range(self.kd_online_topk)) ) current_target_mask.append([0] * self.kd_online_topk) - for i in range(min(0, seq_len - len(current_target_logprobs))): + for i in range(max(0, seq_len - len(current_target_logprobs))): current_target_logprobs.append( [-float("inf")] * self.kd_online_topk ) - current_target_token_ids.append( - list(range(self.kd_online_topk)) - ) + current_target_token_ids.append(list(range(self.kd_online_topk))) current_target_mask.append([0] * self.kd_online_topk) ret_logprobs_data["target_token_ids"].append(current_target_token_ids)