handle input only for online

This commit is contained in:
Wing Lian
2026-02-04 10:53:10 -05:00
parent a0669335e2
commit 17b01bfe36
3 changed files with 8 additions and 1 deletions

View File

@@ -53,6 +53,7 @@ class KDArgs(BaseModel):
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_server_model: str | None = None
kd_online_timeout: int | None = 120
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd

View File

@@ -13,6 +13,12 @@ class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
# batching doesn't work well for logprob data
return False
def _get_messages(self, prompt):
input_prompt = prompt.get("problem")
return [
{"role": "user", "content": input_prompt},
]
def _tokenize_single_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
tools = self._get_tools(prompt)

View File

@@ -196,7 +196,7 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
def get_teacher_logprobs(self, input_ids, labels):
request_body = {
"model": "arcee-ai/Trinity-Large-Preview",
"model": self.axolotl_cfg.kd_online_server_model,
"prompt": input_ids,
"logprobs": self.axolotl_cfg.kd_online_topk,
"echo": True,