fix check
This commit is contained in:
@@ -1,9 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Packed data loader for online teacher training supporting vllm and sglang.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
from axolotl.utils.data.utils import retry_on_request_exceptions
|
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||||
|
|
||||||
@@ -11,6 +15,9 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Collator for online teacher training.
|
||||||
|
"""
|
||||||
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -25,11 +32,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if kd_online_server_base_url is None:
|
if kd_online_server_base_url is None:
|
||||||
raise ValueError("kd_online_server_base_url must be provided for OnlineTeacherDataloader")
|
raise ValueError(
|
||||||
|
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
|
||||||
|
)
|
||||||
if kd_online_topk is None or kd_online_topk <= 0:
|
if kd_online_topk is None or kd_online_topk <= 0:
|
||||||
raise ValueError("kd_online_topk must be a positive integer for OnlineTeacherDataloader")
|
raise ValueError(
|
||||||
|
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
|
||||||
|
)
|
||||||
|
|
||||||
self.kd_online_server_base_url = kd_online_server_base_url.rstrip('/')
|
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
|
||||||
self.kd_online_topk = kd_online_topk
|
self.kd_online_topk = kd_online_topk
|
||||||
self.kd_temperature = kd_temperature
|
self.kd_temperature = kd_temperature
|
||||||
self.kd_online_server = kd_online_server
|
self.kd_online_server = kd_online_server
|
||||||
@@ -40,7 +51,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||||
"""
|
"""
|
||||||
if not raw_logprobs or self.kd_online_topk == 0:
|
if not raw_logprobs or self.kd_online_topk == 0:
|
||||||
return [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
return (
|
||||||
|
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
|
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
|
||||||
# This should ideally be handled by the caller ensuring correct padding/truncation first
|
# This should ideally be handled by the caller ensuring correct padding/truncation first
|
||||||
@@ -50,9 +63,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
f"Logprobs length mismatch in _normalize_logprobs. "
|
f"Logprobs length mismatch in _normalize_logprobs. "
|
||||||
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
|
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
|
||||||
)
|
)
|
||||||
padded_logprobs = raw_logprobs[:self.kd_online_topk]
|
padded_logprobs = raw_logprobs[: self.kd_online_topk]
|
||||||
if len(padded_logprobs) < self.kd_online_topk:
|
if len(padded_logprobs) < self.kd_online_topk:
|
||||||
padded_logprobs.extend([-float("inf")] * (self.kd_online_topk - len(padded_logprobs)))
|
padded_logprobs.extend(
|
||||||
|
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
|
||||||
|
)
|
||||||
raw_logprobs = padded_logprobs
|
raw_logprobs = padded_logprobs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -60,19 +75,27 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
# Convert logprobs at T_online to probabilities
|
# Convert logprobs at T_online to probabilities
|
||||||
# use log sum exp trick to avoid underflow
|
# use log sum exp trick to avoid underflow
|
||||||
position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True)
|
position_logprobs_lse = torch.logsumexp(
|
||||||
teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse)
|
position_logprobs_tensor, dim=-1, keepdim=True
|
||||||
|
)
|
||||||
|
teacher_probs_t_online = torch.exp(
|
||||||
|
position_logprobs_tensor - position_logprobs_lse
|
||||||
|
)
|
||||||
|
|
||||||
# Normalize probabilities (sum to 1)
|
# Normalize probabilities (sum to 1)
|
||||||
# This is important if the top-k from server aren't a full distribution
|
# This is important if the top-k from server aren't a full distribution
|
||||||
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
|
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
|
||||||
if teacher_probs_t_online_sum.item() > 1e-9:
|
if teacher_probs_t_online_sum.item() > 1e-9:
|
||||||
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
|
teacher_probs_t_online = (
|
||||||
|
teacher_probs_t_online / teacher_probs_t_online_sum
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# If sum is zero, create uniform distribution to avoid NaN/Inf later
|
# If sum is zero, create uniform distribution to avoid NaN/Inf later
|
||||||
# This can happen if all raw_logprobs are -inf
|
# This can happen if all raw_logprobs are -inf
|
||||||
if self.kd_online_topk > 0:
|
if self.kd_online_topk > 0:
|
||||||
teacher_probs_t_online = torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
|
teacher_probs_t_online = (
|
||||||
|
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
|
||||||
|
)
|
||||||
# else: leave as is, will result in -inf logprobs
|
# else: leave as is, will result in -inf logprobs
|
||||||
#
|
#
|
||||||
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
|
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
|
||||||
@@ -83,12 +106,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
return final_logprobs_tensor.tolist()
|
return final_logprobs_tensor.tolist()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True)
|
LOG.error(
|
||||||
|
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
# Fallback to (padded/truncated) raw logprobs if scaling fails
|
# Fallback to (padded/truncated) raw logprobs if scaling fails
|
||||||
return raw_logprobs
|
return raw_logprobs
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||||
def fetch_online_logprobs_sglang(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
|
def fetch_online_logprobs_sglang(
|
||||||
|
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
||||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||||
@@ -130,69 +158,96 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# Return empty data; items processed later will get default empty KD fields
|
# Return empty data; items processed later will get default empty KD fields
|
||||||
return ret_logprobs_data
|
return ret_logprobs_data
|
||||||
|
|
||||||
|
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||||
for sequence_data, seq_input_ids, seq_labels in zip(api_data, batch_input_ids, labels):
|
api_data, batch_input_ids, labels
|
||||||
|
):
|
||||||
current_target_logprobs = []
|
current_target_logprobs = []
|
||||||
current_target_token_ids = []
|
current_target_token_ids = []
|
||||||
current_target_mask = []
|
current_target_mask = []
|
||||||
|
|
||||||
meta_info = sequence_data.pop("meta_info", {})
|
meta_info = sequence_data.pop("meta_info", {})
|
||||||
# Ensure input_top_logprobs is a list
|
# Ensure input_top_logprobs is a list
|
||||||
input_top_logprobs: Optional[list[None |list[tuple]]] = meta_info.pop("input_top_logprobs", [])
|
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
|
||||||
|
"input_top_logprobs", []
|
||||||
|
)
|
||||||
if not isinstance(input_top_logprobs, list):
|
if not isinstance(input_top_logprobs, list):
|
||||||
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
|
LOG.warning(
|
||||||
input_top_logprobs = [] # Treat as empty
|
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||||
|
)
|
||||||
|
input_top_logprobs = [] # Treat as empty
|
||||||
|
|
||||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||||
|
|
||||||
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
|
for i, input_id, label in zip(
|
||||||
|
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||||
|
):
|
||||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||||
# this is always the case for the first token.
|
# this is always the case for the first token.
|
||||||
# there is never logprob data for the first token since that's a true input
|
# there is never logprob data for the first token since that's a true input
|
||||||
# so we replace the None value with padding data
|
# so we replace the None value with padding data
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
current_target_logprobs.append(
|
||||||
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
|
elif (
|
||||||
|
i < len(input_top_logprobs)
|
||||||
|
and input_top_logprobs[i] is not None
|
||||||
|
):
|
||||||
pos_top_logprobs_data = input_top_logprobs[i]
|
pos_top_logprobs_data = input_top_logprobs[i]
|
||||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||||
if not (isinstance(pos_top_logprobs_data, list) and \
|
if not (
|
||||||
all(isinstance(item, list) for item in pos_top_logprobs_data) and \
|
isinstance(pos_top_logprobs_data, list)
|
||||||
len(pos_top_logprobs_data) > 0 and len(pos_top_logprobs_data[0]) == 3): # [logprob, token_id, token_str]
|
and all(
|
||||||
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
|
isinstance(item, list) for item in pos_top_logprobs_data
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
)
|
||||||
|
and len(pos_top_logprobs_data) > 0
|
||||||
|
and len(pos_top_logprobs_data[0]) == 3
|
||||||
|
): # [logprob, token_id, token_str]
|
||||||
|
LOG.warning(
|
||||||
|
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||||
|
)
|
||||||
|
current_target_logprobs.append(
|
||||||
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||||
pos_logprobs_raw, pos_token_ids, _ = [list(row) for row in zip(*pos_top_logprobs_data)]
|
pos_logprobs_raw, pos_token_ids, _ = [
|
||||||
|
list(row) for row in zip(*pos_top_logprobs_data)
|
||||||
|
]
|
||||||
|
|
||||||
# Ensure correct length (top_k)
|
# Ensure correct length (top_k)
|
||||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||||
|
|
||||||
# truncate to top_k in case the response was longer
|
# truncate to top_k in case the response was longer
|
||||||
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
|
current_target_token_ids.append(
|
||||||
scaled_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
|
pos_token_ids[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
scaled_logprobs_for_position = self._normalize_logprobs(
|
||||||
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
|
)
|
||||||
current_target_logprobs.append(scaled_logprobs_for_position)
|
current_target_logprobs.append(scaled_logprobs_for_position)
|
||||||
|
|
||||||
# Mask depends on the corresponding label for the student
|
# Mask depends on the corresponding label for the student
|
||||||
label_for_pos = seq_labels[i] if i < len(seq_labels) else self.DEFAULT_LABEL_PAD_TOKEN_ID
|
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||||
if label_for_pos == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
else:
|
else:
|
||||||
current_target_mask.append([1] * self.kd_online_topk)
|
current_target_mask.append([1] * self.kd_online_topk)
|
||||||
else:
|
else:
|
||||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
current_target_logprobs.append(
|
||||||
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
|
|
||||||
|
|
||||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
||||||
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
||||||
ret_logprobs_data["target_mask"].append(current_target_mask)
|
ret_logprobs_data["target_mask"].append(current_target_mask)
|
||||||
@@ -201,8 +256,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||||
raise e
|
raise e
|
||||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||||
except Exception as e: # Catch other potential errors during processing
|
except Exception as e: # Catch other potential errors during processing
|
||||||
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
|
LOG.error(
|
||||||
|
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
# Return initialized empty data
|
# Return initialized empty data
|
||||||
# return {
|
# return {
|
||||||
@@ -211,11 +269,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# "target_mask": [],
|
# "target_mask": [],
|
||||||
# }
|
# }
|
||||||
|
|
||||||
|
|
||||||
return ret_logprobs_data
|
return ret_logprobs_data
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||||
def fetch_online_logprobs_vllm(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
|
def fetch_online_logprobs_vllm(
|
||||||
|
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
||||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||||
@@ -258,7 +317,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# Return empty data; items processed later will get default empty KD fields
|
# Return empty data; items processed later will get default empty KD fields
|
||||||
return ret_logprobs_data
|
return ret_logprobs_data
|
||||||
|
|
||||||
for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels):
|
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||||
|
choices, batch_input_ids, labels
|
||||||
|
):
|
||||||
# seq_input_ids: List[int]
|
# seq_input_ids: List[int]
|
||||||
# seq_labels: List[int]
|
# seq_labels: List[int]
|
||||||
|
|
||||||
@@ -267,7 +328,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_mask = []
|
current_target_mask = []
|
||||||
|
|
||||||
# Ensure input_top_logprobs is a list
|
# Ensure input_top_logprobs is a list
|
||||||
input_top_logprobs: Optional[list[None | list[tuple]]] = sequence_data.pop("prompt_logprobs", [])
|
input_top_logprobs: Optional[list[None | list[tuple]]] = (
|
||||||
|
sequence_data.pop("prompt_logprobs", [])
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
vllm api data for prompt logprobs looks like:
|
vllm api data for prompt logprobs looks like:
|
||||||
"prompt_logprobs": [
|
"prompt_logprobs": [
|
||||||
@@ -289,8 +352,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
if not isinstance(input_top_logprobs, list):
|
if not isinstance(input_top_logprobs, list):
|
||||||
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
|
LOG.warning(
|
||||||
input_top_logprobs = [] # Treat as empty
|
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||||
|
)
|
||||||
|
input_top_logprobs = [] # Treat as empty
|
||||||
|
|
||||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||||
@@ -298,23 +363,43 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# generate a hash over seq_input_ids and convert it to an int
|
# generate a hash over seq_input_ids and convert it to an int
|
||||||
hash_input_ids: int = hash(tuple(seq_input_ids))
|
hash_input_ids: int = hash(tuple(seq_input_ids))
|
||||||
|
|
||||||
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
|
for i, _, label in zip(
|
||||||
|
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||||
|
):
|
||||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||||
# this is always the case for the first token.
|
# this is always the case for the first token.
|
||||||
# there is never logprob data for the first token since that's a true input
|
# there is never logprob data for the first token since that's a true input
|
||||||
# so we replace the None value with padding data
|
# so we replace the None value with padding data
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
current_target_logprobs.append(
|
||||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
|
current_target_token_ids.append(
|
||||||
|
list(range(self.kd_online_topk))
|
||||||
|
)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
|
elif (
|
||||||
|
i < len(input_top_logprobs)
|
||||||
|
and input_top_logprobs[i] is not None
|
||||||
|
):
|
||||||
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
|
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
|
||||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||||
if not (isinstance(pos_top_logprobs_data, dict) and \
|
if not (
|
||||||
all(isinstance(item, dict) for item in pos_top_logprobs_data.values()) and \
|
isinstance(pos_top_logprobs_data, dict)
|
||||||
len(pos_top_logprobs_data.keys()) > 0): # [logprob, token_id, token_str]
|
and all(
|
||||||
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
|
isinstance(item, dict)
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
for item in pos_top_logprobs_data.values()
|
||||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
)
|
||||||
|
and len(pos_top_logprobs_data.keys()) > 0
|
||||||
|
): # [logprob, token_id, token_str]
|
||||||
|
LOG.warning(
|
||||||
|
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||||
|
)
|
||||||
|
current_target_logprobs.append(
|
||||||
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
|
current_target_token_ids.append(
|
||||||
|
list(range(self.kd_online_topk))
|
||||||
|
)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -322,23 +407,31 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
pos_token_ids = pos_top_logprobs_data.keys()
|
pos_token_ids = pos_top_logprobs_data.keys()
|
||||||
pos_logprobs_dict = pos_top_logprobs_data.values()
|
pos_logprobs_dict = pos_top_logprobs_data.values()
|
||||||
pos_token_ids = [int(token_id) for token_id in pos_token_ids]
|
pos_token_ids = [int(token_id) for token_id in pos_token_ids]
|
||||||
pos_logprobs_raw = [float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict]
|
pos_logprobs_raw = [
|
||||||
|
float(logprob.get("logprob", -float("inf")))
|
||||||
|
for logprob in pos_logprobs_dict
|
||||||
|
]
|
||||||
|
|
||||||
# Ensure correct length (top_k)
|
# Ensure correct length (top_k)
|
||||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||||
LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
|
LOG.warning(
|
||||||
|
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
|
||||||
|
)
|
||||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||||
|
|
||||||
# truncate to top_k in case the response was longer
|
# truncate to top_k in case the response was longer
|
||||||
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
|
current_target_token_ids.append(
|
||||||
|
pos_token_ids[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
|
||||||
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
|
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
|
||||||
# current_target_logprobs.append(normalized_logprobs_for_position)
|
# current_target_logprobs.append(normalized_logprobs_for_position)
|
||||||
# don't normalize for now as the probs seem to sum to 1.0 already
|
# don't normalize for now as the probs seem to sum to 1.0 already
|
||||||
current_target_logprobs.append(pos_logprobs_raw[:self.kd_online_topk])
|
current_target_logprobs.append(
|
||||||
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
|
||||||
# Mask depends on the corresponding label for the student
|
# Mask depends on the corresponding label for the student
|
||||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||||
@@ -347,8 +440,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_mask.append([1] * self.kd_online_topk)
|
current_target_mask.append([1] * self.kd_online_topk)
|
||||||
else:
|
else:
|
||||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||||
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
|
current_target_logprobs.append(
|
||||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
[-float("inf")] * self.kd_online_topk
|
||||||
|
)
|
||||||
|
current_target_token_ids.append(
|
||||||
|
list(range(self.kd_online_topk))
|
||||||
|
)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
|
|
||||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
||||||
@@ -364,18 +461,24 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||||
raise e
|
raise e
|
||||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||||
except Exception as e: # Catch other potential errors during processing
|
except Exception as e: # Catch other potential errors during processing
|
||||||
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
|
LOG.error(
|
||||||
|
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return ret_logprobs_data
|
return ret_logprobs_data
|
||||||
|
|
||||||
def __call__(self, features: List[List[Dict[str, Any]]],
|
def __call__(
|
||||||
return_tensors: Optional[str] = None) -> Dict[str, Any]:
|
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
if not features:
|
if not features:
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
for sub_batch_features in features: # sub_batch_features is List[Dict[str, Any]]
|
for (
|
||||||
|
sub_batch_features
|
||||||
|
) in features: # sub_batch_features is List[Dict[str, Any]]
|
||||||
if not sub_batch_features:
|
if not sub_batch_features:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -386,7 +489,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
for item_dict in sub_batch_features:
|
for item_dict in sub_batch_features:
|
||||||
if not isinstance(item_dict, dict):
|
if not isinstance(item_dict, dict):
|
||||||
LOG.warning(f"Skipping non-dict item in sub_batch_features: {item_dict}")
|
LOG.warning(
|
||||||
|
f"Skipping non-dict item in sub_batch_features: {item_dict}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
current_input_ids = item_dict.get("input_ids")
|
current_input_ids = item_dict.get("input_ids")
|
||||||
@@ -394,8 +499,16 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
if current_input_ids is not None and current_labels is not None:
|
if current_input_ids is not None and current_labels is not None:
|
||||||
# Ensure input_ids and labels are lists of ints for JSON serialization
|
# Ensure input_ids and labels are lists of ints for JSON serialization
|
||||||
input_ids_list = current_input_ids.tolist() if hasattr(current_input_ids, "tolist") else list(current_input_ids)
|
input_ids_list = (
|
||||||
labels_list = current_labels.tolist() if hasattr(current_labels, "tolist") else list(current_labels)
|
current_input_ids.tolist()
|
||||||
|
if hasattr(current_input_ids, "tolist")
|
||||||
|
else list(current_input_ids)
|
||||||
|
)
|
||||||
|
labels_list = (
|
||||||
|
current_labels.tolist()
|
||||||
|
if hasattr(current_labels, "tolist")
|
||||||
|
else list(current_labels)
|
||||||
|
)
|
||||||
|
|
||||||
input_ids_for_api_call.append(input_ids_list)
|
input_ids_for_api_call.append(input_ids_list)
|
||||||
labels_for_api_call.append(labels_list)
|
labels_for_api_call.append(labels_list)
|
||||||
@@ -408,29 +521,46 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
item_dict.setdefault("target_mask", [])
|
item_dict.setdefault("target_mask", [])
|
||||||
|
|
||||||
# print(items_for_api_call)
|
# print(items_for_api_call)
|
||||||
if items_for_api_call: # Only call API if there's something to process
|
if items_for_api_call: # Only call API if there's something to process
|
||||||
if self.kd_online_server == "sglang":
|
if self.kd_online_server == "sglang":
|
||||||
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(input_ids_for_api_call, labels_for_api_call)
|
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
|
||||||
|
input_ids_for_api_call, labels_for_api_call
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(input_ids_for_api_call, labels_for_api_call)
|
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
|
||||||
|
input_ids_for_api_call, labels_for_api_call
|
||||||
|
)
|
||||||
|
|
||||||
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
|
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
|
||||||
# Each value is a list, corresponding to items_for_api_call
|
# Each value is a list, corresponding to items_for_api_call
|
||||||
for i, item_to_update in enumerate(items_for_api_call):
|
for i, item_to_update in enumerate(items_for_api_call):
|
||||||
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
|
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
|
||||||
if api_responses_for_sub_batch and \
|
if api_responses_for_sub_batch and i < len(
|
||||||
i < len(api_responses_for_sub_batch): # Check bounds
|
api_responses_for_sub_batch["target_token_ids"]
|
||||||
assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"])
|
): # Check bounds
|
||||||
assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"])
|
assert len(
|
||||||
assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"])
|
api_responses_for_sub_batch["target_token_ids"][i]
|
||||||
item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
|
) == len(item_to_update["input_ids"])
|
||||||
item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
|
assert len(
|
||||||
item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
|
api_responses_for_sub_batch["target_logprobs"][i]
|
||||||
|
) == len(item_to_update["input_ids"])
|
||||||
|
assert len(
|
||||||
|
api_responses_for_sub_batch["target_mask"][i]
|
||||||
|
) == len(item_to_update["labels"])
|
||||||
|
item_to_update["target_token_ids"] = (
|
||||||
|
api_responses_for_sub_batch["target_token_ids"][i]
|
||||||
|
)
|
||||||
|
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
|
||||||
|
"target_logprobs"
|
||||||
|
][i]
|
||||||
|
item_to_update["target_mask"] = api_responses_for_sub_batch[
|
||||||
|
"target_mask"
|
||||||
|
][i]
|
||||||
else:
|
else:
|
||||||
# API call failed for this item, or response was shorter than expected.
|
# API call failed for this item, or response was shorter than expected.
|
||||||
# Ensure KD fields are initialized as empty lists.
|
# Ensure KD fields are initialized as empty lists.
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. "
|
f" (index {i}), or API response was too short. "
|
||||||
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
|
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
|
||||||
)
|
)
|
||||||
item_to_update.setdefault("target_token_ids", [])
|
item_to_update.setdefault("target_token_ids", [])
|
||||||
|
|||||||
Reference in New Issue
Block a user