From b75db1361536220c37cabea5da487c2cc34bbce2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 May 2025 21:32:11 -0400 Subject: [PATCH] fix check --- .../kd/collator_online_teacher.py | 294 +++++++++++++----- 1 file changed, 212 insertions(+), 82 deletions(-) diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py index 1a0173760..503006f15 100644 --- a/src/axolotl/integrations/kd/collator_online_teacher.py +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -1,9 +1,13 @@ +""" +Packed data loader for online teacher training supporting vllm and sglang. +""" +import logging +from typing import Any, Dict, List, Optional + import pandas as pd import requests -import logging -from typing import List, Optional, Dict, Any - import torch + from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.data.utils import retry_on_request_exceptions @@ -11,6 +15,9 @@ LOG = logging.getLogger(__name__) class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): + """ + Collator for online teacher training. + """ DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 def __init__( @@ -25,11 +32,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): super().__init__(*args, **kwargs) if kd_online_server_base_url is None: - raise ValueError("kd_online_server_base_url must be provided for OnlineTeacherDataloader") + raise ValueError( + "kd_online_server_base_url must be provided for OnlineTeacherDataloader" + ) if kd_online_topk is None or kd_online_topk <= 0: - raise ValueError("kd_online_topk must be a positive integer for OnlineTeacherDataloader") + raise ValueError( + "kd_online_topk must be a positive integer for OnlineTeacherDataloader" + ) - self.kd_online_server_base_url = kd_online_server_base_url.rstrip('/') + self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/") self.kd_online_topk = kd_online_topk self.kd_temperature = kd_temperature self.kd_online_server = kd_online_server @@ -40,7 +51,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. """ if not raw_logprobs or self.kd_online_topk == 0: - return [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] + return ( + [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] + ) # Ensure raw_logprobs matches kd_online_topk length for tensor operations # This should ideally be handled by the caller ensuring correct padding/truncation first @@ -50,9 +63,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): f"Logprobs length mismatch in _normalize_logprobs. " f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate." ) - padded_logprobs = raw_logprobs[:self.kd_online_topk] + padded_logprobs = raw_logprobs[: self.kd_online_topk] if len(padded_logprobs) < self.kd_online_topk: - padded_logprobs.extend([-float("inf")] * (self.kd_online_topk - len(padded_logprobs))) + padded_logprobs.extend( + [-float("inf")] * (self.kd_online_topk - len(padded_logprobs)) + ) raw_logprobs = padded_logprobs try: @@ -60,19 +75,27 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # Convert logprobs at T_online to probabilities # use log sum exp trick to avoid underflow - position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True) - teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse) + position_logprobs_lse = torch.logsumexp( + position_logprobs_tensor, dim=-1, keepdim=True + ) + teacher_probs_t_online = torch.exp( + position_logprobs_tensor - position_logprobs_lse + ) # Normalize probabilities (sum to 1) # This is important if the top-k from server aren't a full distribution teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True) if teacher_probs_t_online_sum.item() > 1e-9: - teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum + teacher_probs_t_online = ( + teacher_probs_t_online / teacher_probs_t_online_sum + ) else: # If sum is zero, create uniform distribution to avoid NaN/Inf later # This can happen if all raw_logprobs are -inf if self.kd_online_topk > 0: - teacher_probs_t_online = torch.ones_like(teacher_probs_t_online) / self.kd_online_topk + teacher_probs_t_online = ( + torch.ones_like(teacher_probs_t_online) / self.kd_online_topk + ) # else: leave as is, will result in -inf logprobs # # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum( @@ -83,12 +106,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): return final_logprobs_tensor.tolist() except Exception as e: - LOG.error(f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True) + LOG.error( + f"Error during online logprob scaling: {e}. Returning raw logprobs.", + exc_info=True, + ) # Fallback to (padded/truncated) raw logprobs if scaling fails return raw_logprobs @retry_on_request_exceptions(max_retries=10, delay=5) - def fetch_online_logprobs_sglang(self, batch_input_ids: List[List[int]], labels: List[List[int]]): + def fetch_online_logprobs_sglang( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): """ Fetches logprobs from an online teacher served by vllm for a batch of input_ids. Assumes API returns token IDs as strings in logprob dictionary keys. @@ -130,69 +158,96 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # Return empty data; items processed later will get default empty KD fields return ret_logprobs_data - - for sequence_data, seq_input_ids, seq_labels in zip(api_data, batch_input_ids, labels): + for sequence_data, seq_input_ids, seq_labels in zip( + api_data, batch_input_ids, labels + ): current_target_logprobs = [] current_target_token_ids = [] current_target_mask = [] meta_info = sequence_data.pop("meta_info", {}) # Ensure input_top_logprobs is a list - input_top_logprobs: Optional[list[None |list[tuple]]] = meta_info.pop("input_top_logprobs", []) + input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop( + "input_top_logprobs", [] + ) if not isinstance(input_top_logprobs, list): - LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.") - input_top_logprobs = [] # Treat as empty + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty # basic check that the logprob data len matches the input len, so no need to handle padding assert len(seq_input_ids) == len(input_top_logprobs) - for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels): + for i, input_id, label in zip( + range(len(seq_input_ids)), seq_input_ids, seq_labels + ): if i < len(input_top_logprobs) and input_top_logprobs[i] is None: # this is always the case for the first token. # there is never logprob data for the first token since that's a true input # so we replace the None value with padding data - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) current_target_token_ids.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk) - elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None: + elif ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): pos_top_logprobs_data = input_top_logprobs[i] # Ensure pos_top_logprobs_data is a list of lists as expected - if not (isinstance(pos_top_logprobs_data, list) and \ - all(isinstance(item, list) for item in pos_top_logprobs_data) and \ - len(pos_top_logprobs_data) > 0 and len(pos_top_logprobs_data[0]) == 3): # [logprob, token_id, token_str] - LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.") - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) + if not ( + isinstance(pos_top_logprobs_data, list) + and all( + isinstance(item, list) for item in pos_top_logprobs_data + ) + and len(pos_top_logprobs_data) > 0 + and len(pos_top_logprobs_data[0]) == 3 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." + ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) current_target_token_ids.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk) continue # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids - pos_logprobs_raw, pos_token_ids, _ = [list(row) for row in zip(*pos_top_logprobs_data)] + pos_logprobs_raw, pos_token_ids, _ = [ + list(row) for row in zip(*pos_top_logprobs_data) + ] # Ensure correct length (top_k) if len(pos_logprobs_raw) < self.kd_online_topk: pad_len = self.kd_online_topk - len(pos_logprobs_raw) pos_logprobs_raw.extend([-float("inf")] * pad_len) - pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id # truncate to top_k in case the response was longer - current_target_token_ids.append(pos_token_ids[:self.kd_online_topk]) - scaled_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk]) + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + scaled_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) current_target_logprobs.append(scaled_logprobs_for_position) # Mask depends on the corresponding label for the student - label_for_pos = seq_labels[i] if i < len(seq_labels) else self.DEFAULT_LABEL_PAD_TOKEN_ID - if label_for_pos == self.DEFAULT_LABEL_PAD_TOKEN_ID: + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: current_target_mask.append([0] * self.kd_online_topk) else: current_target_mask.append([1] * self.kd_online_topk) else: # Pad if no logprobs for this position (either due to length mismatch or None entry) - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) current_target_token_ids.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk) - ret_logprobs_data["target_token_ids"].append(current_target_token_ids) ret_logprobs_data["target_logprobs"].append(current_target_logprobs) ret_logprobs_data["target_mask"].append(current_target_mask) @@ -201,8 +256,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): LOG.error(f"Error fetching logprobs from online teacher: {e}") raise e # ret_logprobs_data will be returned with empty lists, handled by the caller. - except Exception as e: # Catch other potential errors during processing - LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True) + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) raise e # Return initialized empty data # return { @@ -211,11 +269,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # "target_mask": [], # } - return ret_logprobs_data @retry_on_request_exceptions(max_retries=10, delay=5) - def fetch_online_logprobs_vllm(self, batch_input_ids: List[List[int]], labels: List[List[int]]): + def fetch_online_logprobs_vllm( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): """ Fetches logprobs from an online teacher served by vllm for a batch of input_ids. Assumes API returns token IDs as strings in logprob dictionary keys. @@ -258,7 +317,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # Return empty data; items processed later will get default empty KD fields return ret_logprobs_data - for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels): + for sequence_data, seq_input_ids, seq_labels in zip( + choices, batch_input_ids, labels + ): # seq_input_ids: List[int] # seq_labels: List[int] @@ -267,7 +328,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): current_target_mask = [] # Ensure input_top_logprobs is a list - input_top_logprobs: Optional[list[None | list[tuple]]] = sequence_data.pop("prompt_logprobs", []) + input_top_logprobs: Optional[list[None | list[tuple]]] = ( + sequence_data.pop("prompt_logprobs", []) + ) """ vllm api data for prompt logprobs looks like: "prompt_logprobs": [ @@ -289,8 +352,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } """ if not isinstance(input_top_logprobs, list): - LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.") - input_top_logprobs = [] # Treat as empty + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty # basic check that the logprob data len matches the input len, so no need to handle padding assert len(seq_input_ids) == len(input_top_logprobs) @@ -298,23 +363,43 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # generate a hash over seq_input_ids and convert it to an int hash_input_ids: int = hash(tuple(seq_input_ids)) - for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels): + for i, _, label in zip( + range(len(seq_input_ids)), seq_input_ids, seq_labels + ): if i < len(input_top_logprobs) and input_top_logprobs[i] is None: # this is always the case for the first token. # there is never logprob data for the first token since that's a true input # so we replace the None value with padding data - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) - current_target_token_ids.append(list(range(self.kd_online_topk))) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) current_target_mask.append([0] * self.kd_online_topk) - elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None: + elif ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # Ensure pos_top_logprobs_data is a list of lists as expected - if not (isinstance(pos_top_logprobs_data, dict) and \ - all(isinstance(item, dict) for item in pos_top_logprobs_data.values()) and \ - len(pos_top_logprobs_data.keys()) > 0): # [logprob, token_id, token_str] - LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.") - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) - current_target_token_ids.append(list(range(self.kd_online_topk))) + if not ( + isinstance(pos_top_logprobs_data, dict) + and all( + isinstance(item, dict) + for item in pos_top_logprobs_data.values() + ) + and len(pos_top_logprobs_data.keys()) > 0 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." + ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) current_target_mask.append([0] * self.kd_online_topk) continue @@ -322,23 +407,31 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): pos_token_ids = pos_top_logprobs_data.keys() pos_logprobs_dict = pos_top_logprobs_data.values() pos_token_ids = [int(token_id) for token_id in pos_token_ids] - pos_logprobs_raw = [float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict] + pos_logprobs_raw = [ + float(logprob.get("logprob", -float("inf"))) + for logprob in pos_logprobs_dict + ] # Ensure correct length (top_k) if len(pos_logprobs_raw) < self.kd_online_topk: pad_len = self.kd_online_topk - len(pos_logprobs_raw) - LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.") + LOG.warning( + f"Padding position {i} with {pad_len} top-k tokens and logprobs." + ) pos_logprobs_raw.extend([-float("inf")] * pad_len) - pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id # truncate to top_k in case the response was longer - current_target_token_ids.append(pos_token_ids[:self.kd_online_topk]) + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) # normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk]) # current_target_logprobs.append(normalized_logprobs_for_position) # don't normalize for now as the probs seem to sum to 1.0 already - current_target_logprobs.append(pos_logprobs_raw[:self.kd_online_topk]) - + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) # Mask depends on the corresponding label for the student if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: @@ -347,8 +440,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): current_target_mask.append([1] * self.kd_online_topk) else: # Pad if no logprobs for this position (either due to length mismatch or None entry) - current_target_logprobs.append([-float("inf")] * self.kd_online_topk) - current_target_token_ids.append(list(range(self.kd_online_topk))) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) current_target_mask.append([0] * self.kd_online_topk) ret_logprobs_data["target_token_ids"].append(current_target_token_ids) @@ -364,18 +461,24 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): LOG.error(f"Error fetching logprobs from online teacher: {e}") raise e # ret_logprobs_data will be returned with empty lists, handled by the caller. - except Exception as e: # Catch other potential errors during processing - LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True) + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) raise e return ret_logprobs_data - def __call__(self, features: List[List[Dict[str, Any]]], - return_tensors: Optional[str] = None) -> Dict[str, Any]: + def __call__( + self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None + ) -> Dict[str, Any]: if not features: return super().__call__(features, return_tensors=return_tensors) - for sub_batch_features in features: # sub_batch_features is List[Dict[str, Any]] + for ( + sub_batch_features + ) in features: # sub_batch_features is List[Dict[str, Any]] if not sub_batch_features: continue @@ -386,7 +489,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): for item_dict in sub_batch_features: if not isinstance(item_dict, dict): - LOG.warning(f"Skipping non-dict item in sub_batch_features: {item_dict}") + LOG.warning( + f"Skipping non-dict item in sub_batch_features: {item_dict}" + ) continue current_input_ids = item_dict.get("input_ids") @@ -394,8 +499,16 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): if current_input_ids is not None and current_labels is not None: # Ensure input_ids and labels are lists of ints for JSON serialization - input_ids_list = current_input_ids.tolist() if hasattr(current_input_ids, "tolist") else list(current_input_ids) - labels_list = current_labels.tolist() if hasattr(current_labels, "tolist") else list(current_labels) + input_ids_list = ( + current_input_ids.tolist() + if hasattr(current_input_ids, "tolist") + else list(current_input_ids) + ) + labels_list = ( + current_labels.tolist() + if hasattr(current_labels, "tolist") + else list(current_labels) + ) input_ids_for_api_call.append(input_ids_list) labels_for_api_call.append(labels_list) @@ -408,29 +521,46 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): item_dict.setdefault("target_mask", []) # print(items_for_api_call) - if items_for_api_call: # Only call API if there's something to process + if items_for_api_call: # Only call API if there's something to process if self.kd_online_server == "sglang": - api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(input_ids_for_api_call, labels_for_api_call) + api_responses_for_sub_batch = self.fetch_online_logprobs_sglang( + input_ids_for_api_call, labels_for_api_call + ) else: - api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(input_ids_for_api_call, labels_for_api_call) + api_responses_for_sub_batch = self.fetch_online_logprobs_vllm( + input_ids_for_api_call, labels_for_api_call + ) # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask" # Each value is a list, corresponding to items_for_api_call for i, item_to_update in enumerate(items_for_api_call): # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly. - if api_responses_for_sub_batch and \ - i < len(api_responses_for_sub_batch): # Check bounds - assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"]) - assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"]) - assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"]) - item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i] - item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i] - item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i] + if api_responses_for_sub_batch and i < len( + api_responses_for_sub_batch["target_token_ids"] + ): # Check bounds + assert len( + api_responses_for_sub_batch["target_token_ids"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_logprobs"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_mask"][i] + ) == len(item_to_update["labels"]) + item_to_update["target_token_ids"] = ( + api_responses_for_sub_batch["target_token_ids"][i] + ) + item_to_update["target_logprobs"] = api_responses_for_sub_batch[ + "target_logprobs" + ][i] + item_to_update["target_mask"] = api_responses_for_sub_batch[ + "target_mask" + ][i] else: # API call failed for this item, or response was shorter than expected. # Ensure KD fields are initialized as empty lists. LOG.warning( - f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. " + f" (index {i}), or API response was too short. " f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}" ) item_to_update.setdefault("target_token_ids", [])