diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 425d8ddf6..7cc7f1ed2 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -53,6 +53,7 @@ class KDArgs(BaseModel):
     kd_online_server: InferenceServerType | None = Field(
         default_factory=lambda: InferenceServerType.vllm
     )
+    kd_online_server_model: str | None = None
     kd_online_timeout: int | None = 120
     kd_temperature_min: float | None = (
         None  # kd temperature scheduling during online kd
diff --git a/src/axolotl/integrations/kd/online_chat_template.py b/src/axolotl/integrations/kd/online_chat_template.py
index 400c72a56..c176cf156 100644
--- a/src/axolotl/integrations/kd/online_chat_template.py
+++ b/src/axolotl/integrations/kd/online_chat_template.py
@@ -13,6 +13,12 @@ class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
         # batching doesn't work well for logprob data
         return False
 
+    def _get_messages(self, prompt):
+        input_prompt = prompt.get("problem")
+        return [
+            {"role": "user", "content": input_prompt},
+        ]
+
     def _tokenize_single_prompt(self, prompt):
         turns = self.get_conversation_thread(prompt)
         tools = self._get_tools(prompt)
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index cea864357..0a86fbe0b 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -196,7 +196,7 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
 
     def get_teacher_logprobs(self, input_ids, labels):
         request_body = {
-            "model": "arcee-ai/Trinity-Large-Preview",
+            "model": self.axolotl_cfg.kd_online_server_model,
            "prompt": input_ids,
            "logprobs": self.axolotl_cfg.kd_online_topk,
            "echo": True,