more KD updates

This commit is contained in:
Wing Lian
2025-05-31 08:26:29 -04:00
parent 73a84ad0dd
commit d55a51623f
6 changed files with 118 additions and 76 deletions

View File

@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:

View File

@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)

View File

@@ -12,6 +12,7 @@ import torch
from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if len(raw_logprobs) != self.kd_online_topk:
# This case should be rare if pre-padding/truncation is done correctly
LOG.warning(
f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
)
padded_logprobs = raw_logprobs[: self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend(
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
)
raw_logprobs = padded_logprobs
try:
position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(
position_logprobs_tensor, dim=-1, keepdim=True
)
teacher_probs_t_online = torch.exp(
position_logprobs_tensor - position_logprobs_lse
)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = (
teacher_probs_t_online / teacher_probs_t_online_sum
)
else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0:
teacher_probs_t_online = (
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
)
# else: leave as is, will result in -inf logprobs
#
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
# dim=0, keepdim=True
# )
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor.tolist()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
)
# Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
scaled_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(scaled_logprobs_for_position)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -442,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
pos_token_ids[: self.kd_online_topk]
)
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
# current_target_logprobs.append(normalized_logprobs_for_position)
# don't normalize for now as the probs seem to sum to 1.0 already
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:

View File

@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns:
Sum of KL divergence losses for the chunk.
"""
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
@@ -56,6 +61,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
student_logits_topk_temp_scaled - student_lse
)
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
@@ -67,7 +78,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta < 1.0: # Contribution from Forward KL
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_topk_valid
)
kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
if beta > 0.0: # Contribution from Reverse KL
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - target_logprobs_valid
)
kd_loss_per_token += beta * rev_kl_per_token
kd_loss = kd_loss_per_token.sum()
kd_loss = rev_kl_per_token.sum()
else:
# JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss
@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss: bool = True,
temperature: float = 1.0,
beta: float = 0.0,
normalize_topk: bool = True,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk,
target_mask_chunk,
beta=beta,
normalize_topk=normalize_topk,
)
return soft_loss, ce_loss
@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss=compute_ce_loss,
temperature=temperature,
beta=beta,
normalize_topk=normalize_topk,
)
def accumulate_chunk_grads(
@@ -311,7 +341,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled,
self.chunk_size,
self.compute_ce_loss,
self.normalize_topk,
)

View File

@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
self.args.kd_temperature,
self.args.kd_beta,
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
def _set_signature_columns_if_needed(self):

View File

@@ -0,0 +1,39 @@
"""Helper KD utils"""
import torch
from torch import FloatTensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor