use max not min
This commit is contained in:
@@ -55,6 +55,7 @@ class KDPlugin(BasePlugin):
|
||||
"kd_online_topk": cfg.kd_online_topk,
|
||||
"kd_temperature": cfg.kd_temperature,
|
||||
"kd_online_server": cfg.kd_online_server,
|
||||
"kd_online_timeout": cfg.kd_online_timeout,
|
||||
}
|
||||
|
||||
if use_batch_sampler_collator:
|
||||
|
||||
@@ -41,3 +41,4 @@ class KDArgs(BaseModel):
|
||||
kd_online_server_base_url: str | None = None
|
||||
kd_online_topk: int | None = None
|
||||
kd_online_server: InferenceServerType | None = "vllm"
|
||||
kd_online_timeout: int | None = 120
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Packed data loader for online teacher training supporting vllm and sglang.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -18,6 +19,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for online teacher training.
|
||||
"""
|
||||
|
||||
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||
|
||||
def __init__(
|
||||
@@ -147,7 +149,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
||||
response = self.http_session.post(
|
||||
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: list[dict] = response.json()
|
||||
|
||||
@@ -299,7 +303,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
||||
response = self.http_session.post(
|
||||
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: dict = response.json()
|
||||
choices: list[dict] = api_data["choices"]
|
||||
@@ -358,9 +364,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
|
||||
seq_len = len(seq_input_ids)
|
||||
|
||||
for i, _, label in zip(
|
||||
range(seq_len), seq_input_ids, seq_labels
|
||||
):
|
||||
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
@@ -435,13 +439,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
for i in range(min(0, seq_len - len(current_target_logprobs))):
|
||||
for i in range(max(0, seq_len - len(current_target_logprobs))):
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user