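Use max, not min, when padding target logprobs to seq_len (min(0, n) is never positive, so the padding loop never ran)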
This commit is contained in:
@@ -55,6 +55,7 @@ class KDPlugin(BasePlugin):
|
|||||||
"kd_online_topk": cfg.kd_online_topk,
|
"kd_online_topk": cfg.kd_online_topk,
|
||||||
"kd_temperature": cfg.kd_temperature,
|
"kd_temperature": cfg.kd_temperature,
|
||||||
"kd_online_server": cfg.kd_online_server,
|
"kd_online_server": cfg.kd_online_server,
|
||||||
|
"kd_online_timeout": cfg.kd_online_timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
|
|||||||
@@ -41,3 +41,4 @@ class KDArgs(BaseModel):
|
|||||||
kd_online_server_base_url: str | None = None
|
kd_online_server_base_url: str | None = None
|
||||||
kd_online_topk: int | None = None
|
kd_online_topk: int | None = None
|
||||||
kd_online_server: InferenceServerType | None = "vllm"
|
kd_online_server: InferenceServerType | None = "vllm"
|
||||||
|
kd_online_timeout: int | None = 120
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Packed data loader for online teacher training supporting vllm and sglang.
|
Packed data loader for online teacher training supporting vllm and sglang.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
"""
|
"""
|
||||||
Collator for online teacher training.
|
Collator for online teacher training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -147,7 +149,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
response = self.http_session.post(
|
||||||
|
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
api_data: list[dict] = response.json()
|
api_data: list[dict] = response.json()
|
||||||
|
|
||||||
@@ -299,7 +303,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
response = self.http_session.post(
|
||||||
|
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
api_data: dict = response.json()
|
api_data: dict = response.json()
|
||||||
choices: list[dict] = api_data["choices"]
|
choices: list[dict] = api_data["choices"]
|
||||||
@@ -358,9 +364,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
seq_len = len(seq_input_ids)
|
seq_len = len(seq_input_ids)
|
||||||
|
|
||||||
for i, _, label in zip(
|
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
|
||||||
range(seq_len), seq_input_ids, seq_labels
|
|
||||||
):
|
|
||||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||||
# this is always the case for the first token.
|
# this is always the case for the first token.
|
||||||
# there is never logprob data for the first token since that's a true input
|
# there is never logprob data for the first token since that's a true input
|
||||||
@@ -435,13 +439,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
list(range(self.kd_online_topk))
|
list(range(self.kd_online_topk))
|
||||||
)
|
)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
for i in range(min(0, seq_len - len(current_target_logprobs))):
|
for i in range(max(0, seq_len - len(current_target_logprobs))):
|
||||||
current_target_logprobs.append(
|
current_target_logprobs.append(
|
||||||
[-float("inf")] * self.kd_online_topk
|
[-float("inf")] * self.kd_online_topk
|
||||||
)
|
)
|
||||||
current_target_token_ids.append(
|
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
||||||
list(range(self.kd_online_topk))
|
|
||||||
)
|
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
|
|
||||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user