fix check

This commit is contained in:
Wing Lian
2025-05-26 21:32:11 -04:00
parent c7b1db329e
commit b75db13615

View File

@@ -1,9 +1,13 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import logging
from typing import Any, Dict, List, Optional
import pandas as pd
import requests
import logging
from typing import List, Optional, Dict, Any
import torch
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import retry_on_request_exceptions
@@ -11,6 +15,9 @@ LOG = logging.getLogger(__name__)
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
@@ -25,11 +32,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
super().__init__(*args, **kwargs)
if kd_online_server_base_url is None:
raise ValueError("kd_online_server_base_url must be provided for OnlineTeacherDataloader")
raise ValueError(
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
)
if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError("kd_online_topk must be a positive integer for OnlineTeacherDataloader")
raise ValueError(
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
)
self.kd_online_server_base_url = kd_online_server_base_url.rstrip('/')
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
@@ -40,7 +51,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
if not raw_logprobs or self.kd_online_topk == 0:
return [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
return (
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
@@ -50,9 +63,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
)
padded_logprobs = raw_logprobs[:self.kd_online_topk]
padded_logprobs = raw_logprobs[: self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend([-float("inf")] * (self.kd_online_topk - len(padded_logprobs)))
padded_logprobs.extend(
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
)
raw_logprobs = padded_logprobs
try:
@@ -60,19 +75,27 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse)
position_logprobs_lse = torch.logsumexp(
position_logprobs_tensor, dim=-1, keepdim=True
)
teacher_probs_t_online = torch.exp(
position_logprobs_tensor - position_logprobs_lse
)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
teacher_probs_t_online = (
teacher_probs_t_online / teacher_probs_t_online_sum
)
else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0:
teacher_probs_t_online = torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
teacher_probs_t_online = (
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
)
# else: leave as is, will result in -inf logprobs
#
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
@@ -83,12 +106,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
return final_logprobs_tensor.tolist()
except Exception as e:
LOG.error(f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True)
LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
)
# Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
@@ -130,69 +158,96 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(api_data, batch_input_ids, labels):
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None |list[tuple]]] = meta_info.pop("input_top_logprobs", [])
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
"input_top_logprobs", []
)
if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
input_top_logprobs = [] # Treat as empty
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
for i, input_id, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, list) and \
all(isinstance(item, list) for item in pos_top_logprobs_data) and \
len(pos_top_logprobs_data) > 0 and len(pos_top_logprobs_data[0]) == 3): # [logprob, token_id, token_str]
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
if not (
isinstance(pos_top_logprobs_data, list)
and all(
isinstance(item, list) for item in pos_top_logprobs_data
)
and len(pos_top_logprobs_data) > 0
and len(pos_top_logprobs_data[0]) == 3
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [list(row) for row in zip(*pos_top_logprobs_data)]
pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
scaled_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
scaled_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(scaled_logprobs_for_position)
# Mask depends on the corresponding label for the student
label_for_pos = seq_labels[i] if i < len(seq_labels) else self.DEFAULT_LABEL_PAD_TOKEN_ID
if label_for_pos == self.DEFAULT_LABEL_PAD_TOKEN_ID:
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask)
@@ -201,8 +256,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
# Return initialized empty data
# return {
@@ -211,11 +269,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# "target_mask": [],
# }
return ret_logprobs_data
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
def fetch_online_logprobs_vllm(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
@@ -258,7 +317,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels):
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
):
# seq_input_ids: List[int]
# seq_labels: List[int]
@@ -267,7 +328,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_mask = []
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = sequence_data.pop("prompt_logprobs", [])
input_top_logprobs: Optional[list[None | list[tuple]]] = (
sequence_data.pop("prompt_logprobs", [])
)
"""
vllm api data for prompt logprobs looks like:
"prompt_logprobs": [
@@ -289,8 +352,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
"""
if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
input_top_logprobs = [] # Treat as empty
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
@@ -298,23 +363,43 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# generate a hash over seq_input_ids and convert it to an int
hash_input_ids: int = hash(tuple(seq_input_ids))
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, dict) and \
all(isinstance(item, dict) for item in pos_top_logprobs_data.values()) and \
len(pos_top_logprobs_data.keys()) > 0): # [logprob, token_id, token_str]
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append(list(range(self.kd_online_topk)))
if not (
isinstance(pos_top_logprobs_data, dict)
and all(
isinstance(item, dict)
for item in pos_top_logprobs_data.values()
)
and len(pos_top_logprobs_data.keys()) > 0
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
continue
@@ -322,23 +407,31 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
pos_token_ids = pos_top_logprobs_data.keys()
pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [int(token_id) for token_id in pos_token_ids]
pos_logprobs_raw = [float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict]
pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
LOG.warning(
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
# current_target_logprobs.append(normalized_logprobs_for_position)
# don't normalize for now as the probs seem to sum to 1.0 already
current_target_logprobs.append(pos_logprobs_raw[:self.kd_online_topk])
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -347,8 +440,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
@@ -364,18 +461,24 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return ret_logprobs_data
def __call__(self, features: List[List[Dict[str, Any]]],
return_tensors: Optional[str] = None) -> Dict[str, Any]:
def __call__(
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
if not features:
return super().__call__(features, return_tensors=return_tensors)
for sub_batch_features in features: # sub_batch_features is List[Dict[str, Any]]
for (
sub_batch_features
) in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features:
continue
@@ -386,7 +489,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
for item_dict in sub_batch_features:
if not isinstance(item_dict, dict):
LOG.warning(f"Skipping non-dict item in sub_batch_features: {item_dict}")
LOG.warning(
f"Skipping non-dict item in sub_batch_features: {item_dict}"
)
continue
current_input_ids = item_dict.get("input_ids")
@@ -394,8 +499,16 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = current_input_ids.tolist() if hasattr(current_input_ids, "tolist") else list(current_input_ids)
labels_list = current_labels.tolist() if hasattr(current_labels, "tolist") else list(current_labels)
input_ids_list = (
current_input_ids.tolist()
if hasattr(current_input_ids, "tolist")
else list(current_input_ids)
)
labels_list = (
current_labels.tolist()
if hasattr(current_labels, "tolist")
else list(current_labels)
)
input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list)
@@ -408,29 +521,46 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
item_dict.setdefault("target_mask", [])
# print(items_for_api_call)
if items_for_api_call: # Only call API if there's something to process
if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(input_ids_for_api_call, labels_for_api_call)
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
input_ids_for_api_call, labels_for_api_call
)
else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(input_ids_for_api_call, labels_for_api_call)
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
input_ids_for_api_call, labels_for_api_call
)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call):
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
if api_responses_for_sub_batch and \
i < len(api_responses_for_sub_batch): # Check bounds
assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"])
assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"])
assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
if api_responses_for_sub_batch and i < len(
api_responses_for_sub_batch["target_token_ids"]
): # Check bounds
assert len(
api_responses_for_sub_batch["target_token_ids"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_logprobs"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_mask"][i]
) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = (
api_responses_for_sub_batch["target_token_ids"][i]
)
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
"target_logprobs"
][i]
item_to_update["target_mask"] = api_responses_for_sub_batch[
"target_mask"
][i]
else:
# API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists.
LOG.warning(
f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. "
f" (index {i}), or API response was too short. "
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
)
item_to_update.setdefault("target_token_ids", [])