fix check

This commit is contained in:
Wing Lian
2025-05-26 21:32:11 -04:00
parent c7b1db329e
commit b75db13615

View File

@@ -1,9 +1,13 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import logging
from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
import requests import requests
import logging
from typing import List, Optional, Dict, Any
import torch import torch
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import retry_on_request_exceptions from axolotl.utils.data.utils import retry_on_request_exceptions
@@ -11,6 +15,9 @@ LOG = logging.getLogger(__name__)
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__( def __init__(
@@ -25,11 +32,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if kd_online_server_base_url is None: if kd_online_server_base_url is None:
raise ValueError("kd_online_server_base_url must be provided for OnlineTeacherDataloader") raise ValueError(
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
)
if kd_online_topk is None or kd_online_topk <= 0: if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError("kd_online_topk must be a positive integer for OnlineTeacherDataloader") raise ValueError(
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
)
self.kd_online_server_base_url = kd_online_server_base_url.rstrip('/') self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
self.kd_online_topk = kd_online_topk self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server self.kd_online_server = kd_online_server
@@ -40,7 +51,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
""" """
if not raw_logprobs or self.kd_online_topk == 0: if not raw_logprobs or self.kd_online_topk == 0:
return [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] return (
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
# Ensure raw_logprobs matches kd_online_topk length for tensor operations # Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first # This should ideally be handled by the caller ensuring correct padding/truncation first
@@ -50,9 +63,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
f"Logprobs length mismatch in _normalize_logprobs. " f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate." f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
) )
padded_logprobs = raw_logprobs[:self.kd_online_topk] padded_logprobs = raw_logprobs[: self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk: if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend([-float("inf")] * (self.kd_online_topk - len(padded_logprobs))) padded_logprobs.extend(
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
)
raw_logprobs = padded_logprobs raw_logprobs = padded_logprobs
try: try:
@@ -60,19 +75,27 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Convert logprobs at T_online to probabilities # Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow # use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True) position_logprobs_lse = torch.logsumexp(
teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse) position_logprobs_tensor, dim=-1, keepdim=True
)
teacher_probs_t_online = torch.exp(
position_logprobs_tensor - position_logprobs_lse
)
# Normalize probabilities (sum to 1) # Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution # This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True) teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9: if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum teacher_probs_t_online = (
teacher_probs_t_online / teacher_probs_t_online_sum
)
else: else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later # If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf # This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0: if self.kd_online_topk > 0:
teacher_probs_t_online = torch.ones_like(teacher_probs_t_online) / self.kd_online_topk teacher_probs_t_online = (
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
)
# else: leave as is, will result in -inf logprobs # else: leave as is, will result in -inf logprobs
# #
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum( # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
@@ -83,12 +106,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
return final_logprobs_tensor.tolist() return final_logprobs_tensor.tolist()
except Exception as e: except Exception as e:
LOG.error(f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True) LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
)
# Fallback to (padded/truncated) raw logprobs if scaling fails # Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs return raw_logprobs
@retry_on_request_exceptions(max_retries=10, delay=5) @retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(self, batch_input_ids: List[List[int]], labels: List[List[int]]): def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
""" """
Fetches logprobs from an online teacher served by vllm for a batch of input_ids. Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys. Assumes API returns token IDs as strings in logprob dictionary keys.
@@ -130,69 +158,96 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Return empty data; items processed later will get default empty KD fields # Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(
for sequence_data, seq_input_ids, seq_labels in zip(api_data, batch_input_ids, labels): api_data, batch_input_ids, labels
):
current_target_logprobs = [] current_target_logprobs = []
current_target_token_ids = [] current_target_token_ids = []
current_target_mask = [] current_target_mask = []
meta_info = sequence_data.pop("meta_info", {}) meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list # Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None |list[tuple]]] = meta_info.pop("input_top_logprobs", []) input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
"input_top_logprobs", []
)
if not isinstance(input_top_logprobs, list): if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.") LOG.warning(
input_top_logprobs = [] # Treat as empty f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding # basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs) assert len(seq_input_ids) == len(input_top_logprobs)
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels): for i, input_id, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None: if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token. # this is always the case for the first token.
# there is never logprob data for the first token since that's a true input # there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data # so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk) current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None: elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data = input_top_logprobs[i] pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected # Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, list) and \ if not (
all(isinstance(item, list) for item in pos_top_logprobs_data) and \ isinstance(pos_top_logprobs_data, list)
len(pos_top_logprobs_data) > 0 and len(pos_top_logprobs_data[0]) == 3): # [logprob, token_id, token_str] and all(
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.") isinstance(item, list) for item in pos_top_logprobs_data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) )
and len(pos_top_logprobs_data) > 0
and len(pos_top_logprobs_data[0]) == 3
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk) current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
continue continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [list(row) for row in zip(*pos_top_logprobs_data)] pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
]
# Ensure correct length (top_k) # Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk: if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw) pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len) pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer # truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk]) current_target_token_ids.append(
scaled_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk]) pos_token_ids[: self.kd_online_topk]
)
scaled_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(scaled_logprobs_for_position) current_target_logprobs.append(scaled_logprobs_for_position)
# Mask depends on the corresponding label for the student # Mask depends on the corresponding label for the student
label_for_pos = seq_labels[i] if i < len(seq_labels) else self.DEFAULT_LABEL_PAD_TOKEN_ID if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
if label_for_pos == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
else: else:
current_target_mask.append([1] * self.kd_online_topk) current_target_mask.append([1] * self.kd_online_topk)
else: else:
# Pad if no logprobs for this position (either due to length mismatch or None entry) # Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk) current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids) ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
ret_logprobs_data["target_logprobs"].append(current_target_logprobs) ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask) ret_logprobs_data["target_mask"].append(current_target_mask)
@@ -201,8 +256,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
LOG.error(f"Error fetching logprobs from online teacher: {e}") LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller. # ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True) LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e raise e
# Return initialized empty data # Return initialized empty data
# return { # return {
@@ -211,11 +269,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# "target_mask": [], # "target_mask": [],
# } # }
return ret_logprobs_data return ret_logprobs_data
@retry_on_request_exceptions(max_retries=10, delay=5) @retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(self, batch_input_ids: List[List[int]], labels: List[List[int]]): def fetch_online_logprobs_vllm(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
""" """
Fetches logprobs from an online teacher served by vllm for a batch of input_ids. Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys. Assumes API returns token IDs as strings in logprob dictionary keys.
@@ -258,7 +317,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Return empty data; items processed later will get default empty KD fields # Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels): for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
):
# seq_input_ids: List[int] # seq_input_ids: List[int]
# seq_labels: List[int] # seq_labels: List[int]
@@ -267,7 +328,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_mask = [] current_target_mask = []
# Ensure input_top_logprobs is a list # Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = sequence_data.pop("prompt_logprobs", []) input_top_logprobs: Optional[list[None | list[tuple]]] = (
sequence_data.pop("prompt_logprobs", [])
)
""" """
vllm api data for prompt logprobs looks like: vllm api data for prompt logprobs looks like:
"prompt_logprobs": [ "prompt_logprobs": [
@@ -289,8 +352,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
} }
""" """
if not isinstance(input_top_logprobs, list): if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.") LOG.warning(
input_top_logprobs = [] # Treat as empty f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding # basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs) assert len(seq_input_ids) == len(input_top_logprobs)
@@ -298,23 +363,43 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# generate a hash over seq_input_ids and convert it to an int # generate a hash over seq_input_ids and convert it to an int
hash_input_ids: int = hash(tuple(seq_input_ids)) hash_input_ids: int = hash(tuple(seq_input_ids))
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels): for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None: if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token. # this is always the case for the first token.
# there is never logprob data for the first token since that's a true input # there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data # so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) current_target_logprobs.append(
current_target_token_ids.append(list(range(self.kd_online_topk))) [-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None: elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected # Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, dict) and \ if not (
all(isinstance(item, dict) for item in pos_top_logprobs_data.values()) and \ isinstance(pos_top_logprobs_data, dict)
len(pos_top_logprobs_data.keys()) > 0): # [logprob, token_id, token_str] and all(
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.") isinstance(item, dict)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) for item in pos_top_logprobs_data.values()
current_target_token_ids.append(list(range(self.kd_online_topk))) )
and len(pos_top_logprobs_data.keys()) > 0
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
continue continue
@@ -322,23 +407,31 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
pos_token_ids = pos_top_logprobs_data.keys() pos_token_ids = pos_top_logprobs_data.keys()
pos_logprobs_dict = pos_top_logprobs_data.values() pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [int(token_id) for token_id in pos_token_ids] pos_token_ids = [int(token_id) for token_id in pos_token_ids]
pos_logprobs_raw = [float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict] pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
]
# Ensure correct length (top_k) # Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk: if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw) pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.") LOG.warning(
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
)
pos_logprobs_raw.extend([-float("inf")] * pad_len) pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer # truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk]) current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk]) # normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
# current_target_logprobs.append(normalized_logprobs_for_position) # current_target_logprobs.append(normalized_logprobs_for_position)
# don't normalize for now as the probs seem to sum to 1.0 already # don't normalize for now as the probs seem to sum to 1.0 already
current_target_logprobs.append(pos_logprobs_raw[:self.kd_online_topk]) current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student # Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -347,8 +440,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_mask.append([1] * self.kd_online_topk) current_target_mask.append([1] * self.kd_online_topk)
else: else:
# Pad if no logprobs for this position (either due to length mismatch or None entry) # Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk) current_target_logprobs.append(
current_target_token_ids.append(list(range(self.kd_online_topk))) [-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids) ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
@@ -364,18 +461,24 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
LOG.error(f"Error fetching logprobs from online teacher: {e}") LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller. # ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True) LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e raise e
return ret_logprobs_data return ret_logprobs_data
def __call__(self, features: List[List[Dict[str, Any]]], def __call__(
return_tensors: Optional[str] = None) -> Dict[str, Any]: self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
if not features: if not features:
return super().__call__(features, return_tensors=return_tensors) return super().__call__(features, return_tensors=return_tensors)
for sub_batch_features in features: # sub_batch_features is List[Dict[str, Any]] for (
sub_batch_features
) in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features: if not sub_batch_features:
continue continue
@@ -386,7 +489,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
for item_dict in sub_batch_features: for item_dict in sub_batch_features:
if not isinstance(item_dict, dict): if not isinstance(item_dict, dict):
LOG.warning(f"Skipping non-dict item in sub_batch_features: {item_dict}") LOG.warning(
f"Skipping non-dict item in sub_batch_features: {item_dict}"
)
continue continue
current_input_ids = item_dict.get("input_ids") current_input_ids = item_dict.get("input_ids")
@@ -394,8 +499,16 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
if current_input_ids is not None and current_labels is not None: if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization # Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = current_input_ids.tolist() if hasattr(current_input_ids, "tolist") else list(current_input_ids) input_ids_list = (
labels_list = current_labels.tolist() if hasattr(current_labels, "tolist") else list(current_labels) current_input_ids.tolist()
if hasattr(current_input_ids, "tolist")
else list(current_input_ids)
)
labels_list = (
current_labels.tolist()
if hasattr(current_labels, "tolist")
else list(current_labels)
)
input_ids_for_api_call.append(input_ids_list) input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list) labels_for_api_call.append(labels_list)
@@ -408,29 +521,46 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
item_dict.setdefault("target_mask", []) item_dict.setdefault("target_mask", [])
# print(items_for_api_call) # print(items_for_api_call)
if items_for_api_call: # Only call API if there's something to process if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang": if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(input_ids_for_api_call, labels_for_api_call) api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
input_ids_for_api_call, labels_for_api_call
)
else: else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(input_ids_for_api_call, labels_for_api_call) api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
input_ids_for_api_call, labels_for_api_call
)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask" # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call # Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call): for i, item_to_update in enumerate(items_for_api_call):
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly. # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
if api_responses_for_sub_batch and \ if api_responses_for_sub_batch and i < len(
i < len(api_responses_for_sub_batch): # Check bounds api_responses_for_sub_batch["target_token_ids"]
assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"]) ): # Check bounds
assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"]) assert len(
assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"]) api_responses_for_sub_batch["target_token_ids"][i]
item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i] ) == len(item_to_update["input_ids"])
item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i] assert len(
item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i] api_responses_for_sub_batch["target_logprobs"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_mask"][i]
) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = (
api_responses_for_sub_batch["target_token_ids"][i]
)
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
"target_logprobs"
][i]
item_to_update["target_mask"] = api_responses_for_sub_batch[
"target_mask"
][i]
else: else:
# API call failed for this item, or response was shorter than expected. # API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists. # Ensure KD fields are initialized as empty lists.
LOG.warning( LOG.warning(
f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. " f" (index {i}), or API response was too short. "
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}" f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
) )
item_to_update.setdefault("target_token_ids", []) item_to_update.setdefault("target_token_ids", [])