shift off the first empty token

This commit is contained in:
Wing Lian
2025-05-26 22:23:50 -04:00
parent b75db13615
commit 225b420dc5

View File

@@ -27,6 +27,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
kd_online_topk: Optional[int] = None,
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@@ -45,6 +46,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
@@ -145,7 +147,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
response.raise_for_status()
api_data: list[dict] = response.json()
@@ -262,12 +264,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
exc_info=True,
)
raise e
# Return initialized empty data
# return {
# "target_token_ids": [],
# "target_logprobs": [],
# "target_mask": [],
# }
return ret_logprobs_data
@@ -303,7 +299,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
response.raise_for_status()
api_data: dict = response.json()
choices: list[dict] = api_data["choices"]
@@ -361,22 +357,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
assert len(seq_input_ids) == len(input_top_logprobs)
# generate a hash over seq_input_ids and convert it to an int
hash_input_ids: int = hash(tuple(seq_input_ids))
# hash_input_ids: int = hash(tuple(seq_input_ids))
seq_len = len(seq_input_ids)
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
range(seq_len), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
continue
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
@@ -447,15 +438,23 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(len(current_target_logprobs) - seq_len):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask)
with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")