shift off the first empty token
This commit is contained in:
@@ -27,6 +27,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
kd_online_topk: Optional[int] = None,
|
||||
kd_temperature: Optional[float] = 1.0,
|
||||
kd_online_server: Optional[str] = "vllm",
|
||||
kd_online_timeout: Optional[int] = 120,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -45,6 +46,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
self.kd_temperature = kd_temperature
|
||||
self.kd_online_server = kd_online_server
|
||||
self.http_session = requests.Session()
|
||||
self.kd_online_timeout = kd_online_timeout
|
||||
|
||||
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||
"""
|
||||
@@ -145,7 +147,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
||||
response.raise_for_status()
|
||||
api_data: list[dict] = response.json()
|
||||
|
||||
@@ -262,12 +264,6 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
# Return initialized empty data
|
||||
# return {
|
||||
# "target_token_ids": [],
|
||||
# "target_logprobs": [],
|
||||
# "target_mask": [],
|
||||
# }
|
||||
|
||||
return ret_logprobs_data
|
||||
|
||||
@@ -303,7 +299,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
|
||||
response = self.http_session.post(api_endpoint, json=payload, timeout=self.kd_online_timeout)
|
||||
response.raise_for_status()
|
||||
api_data: dict = response.json()
|
||||
choices: list[dict] = api_data["choices"]
|
||||
@@ -361,22 +357,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||
|
||||
# generate a hash over seq_input_ids and convert it to an int
|
||||
hash_input_ids: int = hash(tuple(seq_input_ids))
|
||||
# hash_input_ids: int = hash(tuple(seq_input_ids))
|
||||
|
||||
seq_len = len(seq_input_ids)
|
||||
|
||||
for i, _, label in zip(
|
||||
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||
range(seq_len), seq_input_ids, seq_labels
|
||||
):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
# so we replace the None value with padding data
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
continue
|
||||
elif (
|
||||
i < len(input_top_logprobs)
|
||||
and input_top_logprobs[i] is not None
|
||||
@@ -447,15 +438,23 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
for i in range(len(current_target_logprobs) - seq_len):
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
||||
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
||||
ret_logprobs_data["target_mask"].append(current_target_mask)
|
||||
|
||||
with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
|
||||
pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
|
||||
with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
|
||||
pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
|
||||
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
|
||||
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
|
||||
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
|
||||
# pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||
|
||||
Reference in New Issue
Block a user