handle input only for online
This commit is contained in:
@@ -53,6 +53,7 @@ class KDArgs(BaseModel):
|
||||
kd_online_server: InferenceServerType | None = Field(
|
||||
default_factory=lambda: InferenceServerType.vllm
|
||||
)
|
||||
kd_online_server_model: str | None = None
|
||||
kd_online_timeout: int | None = 120
|
||||
kd_temperature_min: float | None = (
|
||||
None # kd temperature scheduling during online kd
|
||||
|
||||
@@ -13,6 +13,12 @@ class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
|
||||
# batching doesn't work well for logprob data
|
||||
return False
|
||||
|
||||
def _get_messages(self, prompt):
|
||||
input_prompt = prompt.get("problem")
|
||||
return [
|
||||
{"role": "user", "content": input_prompt},
|
||||
]
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
tools = self._get_tools(prompt)
|
||||
|
||||
@@ -196,7 +196,7 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
||||
|
||||
def get_teacher_logprobs(self, input_ids, labels):
|
||||
request_body = {
|
||||
"model": "arcee-ai/Trinity-Large-Preview",
|
||||
"model": self.axolotl_cfg.kd_online_server_model,
|
||||
"prompt": input_ids,
|
||||
"logprobs": self.axolotl_cfg.kd_online_topk,
|
||||
"echo": True,
|
||||
|
||||
Reference in New Issue
Block a user