Handle input-only prompts for online KD
This commit is contained in:
@@ -53,6 +53,7 @@ class KDArgs(BaseModel):
|
|||||||
kd_online_server: InferenceServerType | None = Field(
|
kd_online_server: InferenceServerType | None = Field(
|
||||||
default_factory=lambda: InferenceServerType.vllm
|
default_factory=lambda: InferenceServerType.vllm
|
||||||
)
|
)
|
||||||
|
kd_online_server_model: str | None = None
|
||||||
kd_online_timeout: int | None = 120
|
kd_online_timeout: int | None = 120
|
||||||
kd_temperature_min: float | None = (
|
kd_temperature_min: float | None = (
|
||||||
None # kd temperature scheduling during online kd
|
None # kd temperature scheduling during online kd
|
||||||
|
|||||||
@@ -13,6 +13,12 @@ class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
|
|||||||
# batching doesn't work well for logprob data
|
# batching doesn't work well for logprob data
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_messages(self, prompt):
|
||||||
|
input_prompt = prompt.get("problem")
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": input_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
def _tokenize_single_prompt(self, prompt):
|
def _tokenize_single_prompt(self, prompt):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
tools = self._get_tools(prompt)
|
tools = self._get_tools(prompt)
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
|||||||
|
|
||||||
def get_teacher_logprobs(self, input_ids, labels):
|
def get_teacher_logprobs(self, input_ids, labels):
|
||||||
request_body = {
|
request_body = {
|
||||||
"model": "arcee-ai/Trinity-Large-Preview",
|
"model": self.axolotl_cfg.kd_online_server_model,
|
||||||
"prompt": input_ids,
|
"prompt": input_ids,
|
||||||
"logprobs": self.axolotl_cfg.kd_online_topk,
|
"logprobs": self.axolotl_cfg.kd_online_topk,
|
||||||
"echo": True,
|
"echo": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user