use max not min

This commit is contained in:
Wing Lian
2025-05-26 22:34:07 -04:00
parent 9eb53f5c9e
commit 90c7228ff9
3 changed files with 13 additions and 9 deletions

View File

@@ -55,6 +55,7 @@ class KDPlugin(BasePlugin):
 "kd_online_topk": cfg.kd_online_topk,
 "kd_temperature": cfg.kd_temperature,
 "kd_online_server": cfg.kd_online_server,
+"kd_online_timeout": cfg.kd_online_timeout,
 }
 if use_batch_sampler_collator:

View File

@@ -41,3 +41,4 @@ class KDArgs(BaseModel):
 kd_online_server_base_url: str | None = None
 kd_online_topk: int | None = None
 kd_online_server: InferenceServerType | None = "vllm"
+kd_online_timeout: int | None = 120

View File

@@ -1,6 +1,7 @@
 """
 Packed data loader for online teacher training supporting vllm and sglang.
 """
+
 import logging
 from typing import Any, Dict, List, Optional
@@ -18,6 +19,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
 """
 Collator for online teacher training.
 """
+
 DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
 def __init__(
@@ -147,7 +149,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
 }
 try:
-response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
+response = self.http_session.post(
+api_endpoint, json=payload, timeout=self.kd_online_timeout
+)
 response.raise_for_status()
 api_data: list[dict] = response.json()
@@ -299,7 +303,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
 }
 try:
-response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
+response = self.http_session.post(
+api_endpoint, json=payload, timeout=self.kd_online_timeout
+)
 response.raise_for_status()
 api_data: dict = response.json()
 choices: list[dict] = api_data["choices"]
@@ -358,9 +364,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
 seq_len = len(seq_input_ids)
-for i, _, label in zip(
-range(seq_len), seq_input_ids, seq_labels
-):
+for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
 if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
 # this is always the case for the first token.
 # there is never logprob data for the first token since that's a true input
@@ -435,13 +439,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
 list(range(self.kd_online_topk))
 )
 current_target_mask.append([0] * self.kd_online_topk)
-for i in range(min(0, seq_len - len(current_target_logprobs))):
+for i in range(max(0, seq_len - len(current_target_logprobs))):
 current_target_logprobs.append(
 [-float("inf")] * self.kd_online_topk
 )
-current_target_token_ids.append(
-list(range(self.kd_online_topk))
-)
+current_target_token_ids.append(list(range(self.kd_online_topk)))
 current_target_mask.append([0] * self.kd_online_topk)
 ret_logprobs_data["target_token_ids"].append(current_target_token_ids)