online kd wip

This commit is contained in:
Wing Lian
2025-05-25 16:52:20 -04:00
parent a8d9fab635
commit b4e96ef12c
6 changed files with 472 additions and 41 deletions

View File

@@ -21,3 +21,32 @@ datasets:
```
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
## Online KD (sglang)
```bash
export UV_TORCH_BACKEND=cu124
uv venv sglang --python 3.11
source sglang/bin/activate
uv pip install --upgrade pip
uv pip install setuptools
uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
```
## Online KD (vllm)
```bash
VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs
256 --gpu_memory_utilization 0.2 --enable-chunked-prefill
```
```bash
vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
```
```bash
python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
```

View File

@@ -47,6 +47,16 @@ class KDPlugin(BasePlugin):
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if cfg.kd_online_server_base_url:
from .collator_online_teacher import OnlineTeacherCollator
return OnlineTeacherCollator, {
"kd_online_server_base_url": cfg.kd_online_server_base_url,
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
}
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD, {}

View File

@@ -15,9 +15,16 @@
"""
Plugin args for KD support.
"""
from enum import Enum
from pydantic import BaseModel
class InferenceServerType(str, Enum):
vllm = "vllm"
sglang = "sglang"
class KDArgs(BaseModel):
"""
Input args for knowledge distillation.
@@ -31,5 +38,6 @@ class KDArgs(BaseModel):
kd_temperature: float | None = None # temperature for sampling during KD
# TODO online kd
# kd_online_server_base_url: str | None = None
# kd_online_topk: int | None = None
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = "vllm"

View File

@@ -0,0 +1,422 @@
import requests
import logging
from typing import List, Optional, Dict, Any
import torch
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
self,
*args: Any,
kd_online_server_base_url: Optional[str] = None,
kd_online_topk: Optional[int] = None,
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if kd_online_server_base_url is None:
raise ValueError("kd_online_server_base_url must be provided for OnlineTeacherDataloader")
if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError("kd_online_topk must be a positive integer for OnlineTeacherDataloader")
self.kd_online_server_base_url = kd_online_server_base_url.rstrip('/')
self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
if not raw_logprobs or self.kd_online_topk == 0:
return [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if len(raw_logprobs) != self.kd_online_topk:
# This case should be rare if pre-padding/truncation is done correctly
LOG.warning(
f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
)
padded_logprobs = raw_logprobs[:self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend([-float("inf")] * (self.kd_online_topk - len(padded_logprobs)))
raw_logprobs = padded_logprobs
try:
position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
# Convert logprobs at T_online to probabilities
teacher_probs_t_online = position_logprobs_tensor.exp()
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0:
teacher_probs_t_online = torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
# else: leave as is, will result in -inf logprobs
#
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
# dim=0, keepdim=True
# )
# TODO Convert back to logprobs using log sum exp trick
teacher_probs_t_online_max = torch.logsumexp(teacher_probs_t_online, dim=-1, keepdim=True)
final_logprobs_tensor = teacher_probs_t_online - teacher_probs_t_online_max
# final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor.tolist()
except Exception as e:
LOG.error(f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True)
# Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
payload = {
"input_ids": batch_input_ids,
"return_logprob": True,
"top_logprobs_num": self.kd_online_topk,
"logprob_start_len": 0,
"return_text_in_logprobs": True,
"echo": True,
"sampling_params": {
"max_new_tokens": 0,
"temperature": self.kd_temperature,
"skip_special_tokens": False,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_logprobs_data = {
"target_token_ids": [],
"target_logprobs": [],
"target_mask": [],
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
response.raise_for_status()
api_data: list[dict] = response.json()
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(api_data, batch_input_ids, labels):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None |list[tuple]]] = meta_info.pop("input_top_logprobs", [])
if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, list) and \
all(isinstance(item, list) for item in pos_top_logprobs_data) and \
len(pos_top_logprobs_data) > 0 and len(pos_top_logprobs_data[0]) == 3): # [logprob, token_id, token_str]
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [list(row) for row in zip(*pos_top_logprobs_data)]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
scaled_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
current_target_logprobs.append(scaled_logprobs_for_position)
# Mask depends on the corresponding label for the student
label_for_pos = seq_labels[i] if i < len(seq_labels) else self.DEFAULT_LABEL_PAD_TOKEN_ID
if label_for_pos == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
raise e
# Return initialized empty data
# return {
# "target_token_ids": [],
# "target_logprobs": [],
# "target_mask": [],
# }
return ret_logprobs_data
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(self, batch_input_ids: List[List[int]], labels: List[List[int]]):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
payload = {
"prompt": batch_input_ids,
"echo": True,
"logprobs": True,
"prompt_logprobs": self.kd_online_topk,
"top_logprobs": self.kd_online_topk,
"max_new_tokens": 0,
"skip_special_tokens": False,
"temperature": self.kd_temperature,
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_logprobs_data = {
"target_token_ids": [],
"target_logprobs": [],
"target_mask": [],
}
try:
response = self.http_session.post(api_endpoint, json=payload, timeout=60)
response.raise_for_status()
api_data: dict = response.json()
choices: list[dict] = api_data["choices"]
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return ret_logprobs_data
for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = sequence_data.pop("prompt_logprobs", [])
"""
vllm api data for prompt logprobs looks like:
"prompt_logprobs": [
null, # first token is always null
{ # second token logprobs
"8948": { # token ID
"logprob": -2.3841830625315197e-06,
"rank": 1,
"decoded_token": "system"
},
"1849": { # token ID
"logprob": -13.187501907348633,
"rank": 2,
"decoded_token": "Ġsystem"
},
... # rest of the top-k tokens/logprobs
},
... # more tokens
}
"""
if not isinstance(input_top_logprobs, list):
LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
input_top_logprobs = [] # Treat as empty
for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (isinstance(pos_top_logprobs_data, dict) and \
all(isinstance(item, dict) for item in pos_top_logprobs_data.values()) and \
len(pos_top_logprobs_data.keys()) > 0): # [logprob, token_id, token_str]
LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_token_ids = pos_top_logprobs_data.keys()
pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [int(token_id) for token_id in pos_token_ids]
pos_logprobs_raw = [float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.debug(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(pos_token_ids[:self.kd_online_topk])
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
# current_target_logprobs.append(normalized_logprobs_for_position)
# don't normalize for now as the probs seem to sum to 1.0 already
current_target_logprobs.append(pos_logprobs_raw[:self.kd_online_topk])
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
ret_logprobs_data["target_mask"].append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(f"Unexpected error processing API response in fetch_online_logprobs: {e}", exc_info=True)
raise e
return ret_logprobs_data
def __call__(self, features: List[List[Dict[str, Any]]],
return_tensors: Optional[str] = None) -> Dict[str, Any]:
if not features:
return super().__call__(features, return_tensors=return_tensors)
for sub_batch_features in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features:
continue
input_ids_for_api_call: List[List[int]] = []
labels_for_api_call: List[List[int]] = []
# Store references to the original item dictionaries to update them in-place
items_for_api_call: List[Dict[str, Any]] = []
for item_dict in sub_batch_features:
if not isinstance(item_dict, dict):
LOG.warning(f"Skipping non-dict item in sub_batch_features: {item_dict}")
continue
current_input_ids = item_dict.get("input_ids")
current_labels = item_dict.get("labels")
if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = current_input_ids.tolist() if hasattr(current_input_ids, "tolist") else list(current_input_ids)
labels_list = current_labels.tolist() if hasattr(current_labels, "tolist") else list(current_labels)
input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list)
items_for_api_call.append(item_dict)
else:
# This item will not get teacher logprobs from the API.
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
item_dict.setdefault("target_token_ids", [])
item_dict.setdefault("target_logprobs", [])
item_dict.setdefault("target_mask", [])
# print(items_for_api_call)
if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(input_ids_for_api_call, labels_for_api_call)
else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(input_ids_for_api_call, labels_for_api_call)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call):
# Check if API call was successful and returned data for this item.
# fetch_online_logprobs returns dict with empty lists if API fails or malformed.
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
if api_responses_for_sub_batch and \
i < len(api_responses_for_sub_batch.get("target_token_ids", [])): # Check bounds
item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
else:
# API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists.
LOG.warning(
f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. "
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
)
item_to_update.setdefault("target_token_ids", [])
item_to_update.setdefault("target_logprobs", [])
item_to_update.setdefault("target_mask", [])
return super().__call__(features, return_tensors=return_tensors)

View File

@@ -62,12 +62,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
# target_logprobs = inputs.pop("target_logprobs")
# target_token_ids = inputs.pop("target_token_ids")
# target_mask = inputs.pop("target_mask")
# seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
@@ -75,36 +69,3 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
return outputs[0]
#
# # FIXME: account for tokenizer.padding_side
# student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
#
# shift_logits = student_logits.contiguous()
# target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
# target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
# target_mask_for_loss = target_mask[..., 1:, :].contiguous()
#
# loss_kd = self.kd_loss_fn(
# shift_logits,
# target_token_ids_for_loss,
# target_logprobs_for_loss,
# target_mask_for_loss,
# num_items_in_batch=num_items_in_batch,
# )
#
# if self.args.kd_ce_alpha > 0:
# kd_alpha = self.args.kd_alpha
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
# else:
# loss = loss_kd
# # Save past state if it exists
# # TODO: this needs to be fixed and made cleaner later.
# if self.args.past_index >= 0:
# self._past = outputs[ # pylint: disable=attribute-defined-outside-init
# self.args.past_index
# ]
#
# if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
# loss *= self.accelerator.num_processes
# return (loss, outputs) if return_outputs else loss

View File

@@ -40,6 +40,7 @@ def retry_on_request_exceptions(
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1: