Use max, not min, when padding target logprobs to seq_len (min(0, …) never iterated, leaving sequences unpadded)

This commit is contained in:
Wing Lian
2025-05-26 22:34:07 -04:00
parent 9eb53f5c9e
commit 90c7228ff9
3 changed files with 13 additions and 9 deletions

View File

@@ -55,6 +55,7 @@ class KDPlugin(BasePlugin):
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
}
if use_batch_sampler_collator:

View File

@@ -41,3 +41,4 @@ class KDArgs(BaseModel):
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = "vllm"
kd_online_timeout: int | None = 120

View File

@@ -1,6 +1,7 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import logging
from typing import Any, Dict, List, Optional
@@ -18,6 +19,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
@@ -147,7 +149,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
response = self.http_session.post(
api_endpoint, json=payload, timeout=self.kd_online_timeout
)
response.raise_for_status()
api_data: list[dict] = response.json()
@@ -299,7 +303,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
response = self.http_session.post(
api_endpoint, json=payload, timeout=self.kd_online_timeout
)
response.raise_for_status()
api_data: dict = response.json()
choices: list[dict] = api_data["choices"]
@@ -358,9 +364,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
seq_len = len(seq_input_ids)
for i, _, label in zip(
range(seq_len), seq_input_ids, seq_labels
):
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
@@ -435,13 +439,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(min(0, seq_len - len(current_target_logprobs))):
for i in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)