more KD updates

This commit is contained in:
Wing Lian
2025-05-31 08:26:29 -04:00
parent 73a84ad0dd
commit d55a51623f
6 changed files with 118 additions and 76 deletions

View File

@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
"kd_alpha": cfg.kd_alpha, "kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature, "kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta, "kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
} }
def get_collator_cls_and_kwargs(self, cfg, is_eval=False): def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
"kd_temperature": cfg.kd_temperature, "kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server, "kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout, "kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
} }
if use_batch_sampler_collator: if use_batch_sampler_collator:

View File

@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
kd_alpha: float | None = None # loss coefficient for KD loss kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd # TODO online kd
kd_online_server_base_url: str | None = None kd_online_server_base_url: str | None = None
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
kd_alpha: float | None = None # loss coefficient for KD loss kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)

View File

@@ -12,6 +12,7 @@ import torch
from orjson import orjson from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
kd_online_server: Optional[str] = "vllm", kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120, kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None, kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any, **kwargs: Any,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
self.http_session = requests.Session() self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
""" """
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
) )
# Ensure raw_logprobs matches kd_online_topk length for tensor operations raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
# This should ideally be handled by the caller ensuring correct padding/truncation first return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
if len(raw_logprobs) != self.kd_online_topk:
# This case should be rare if pre-padding/truncation is done correctly
LOG.warning(
f"Logprobs length mismatch in _normalize_logprobs. "
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
)
padded_logprobs = raw_logprobs[: self.kd_online_topk]
if len(padded_logprobs) < self.kd_online_topk:
padded_logprobs.extend(
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
)
raw_logprobs = padded_logprobs
try:
position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(
position_logprobs_tensor, dim=-1, keepdim=True
)
teacher_probs_t_online = torch.exp(
position_logprobs_tensor - position_logprobs_lse
)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
if teacher_probs_t_online_sum.item() > 1e-9:
teacher_probs_t_online = (
teacher_probs_t_online / teacher_probs_t_online_sum
)
else:
# If sum is zero, create uniform distribution to avoid NaN/Inf later
# This can happen if all raw_logprobs are -inf
if self.kd_online_topk > 0:
teacher_probs_t_online = (
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
)
# else: leave as is, will result in -inf logprobs
#
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
# dim=0, keepdim=True
# )
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor.tolist()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
)
# Fallback to (padded/truncated) raw logprobs if scaling fails
return raw_logprobs
@retry_on_request_exceptions(max_retries=10, delay=5) @retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang( def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]] self, batch_input_ids: List[List[int]], labels: List[List[int]]
): ):
""" """
Fetches logprobs from an online teacher served by vllm for a batch of input_ids. Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys. Assumes API returns token IDs as strings in logprob dictionary keys.
""" """
api_endpoint = f"{self.kd_online_server_base_url}/generate" api_endpoint = f"{self.kd_online_server_base_url}/generate"
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_token_ids.append( current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk] pos_token_ids[: self.kd_online_topk]
) )
scaled_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk] if self.kd_normalize_topk:
) normalized_logprobs_for_position = self._normalize_logprobs(
current_target_logprobs.append(scaled_logprobs_for_position) pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student # Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -442,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
pos_token_ids[: self.kd_online_topk] pos_token_ids[: self.kd_online_topk]
) )
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk]) if self.kd_normalize_topk:
# current_target_logprobs.append(normalized_logprobs_for_position) normalized_logprobs_for_position = self._normalize_logprobs(
# don't normalize for now as the probs seem to sum to 1.0 already pos_logprobs_raw[: self.kd_online_topk]
current_target_logprobs.append( )
pos_logprobs_raw[: self.kd_online_topk] current_target_logprobs.append(
) normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student # Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:

View File

@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase, LigerFusedLinearDistillationBase,
) )
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
""" """
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k] target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0, beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute Top-K KL divergence loss for a chunk. Compute Top-K KL divergence loss for a chunk.
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
0.0 for Forward KL (P_teacher || P_student). 0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher). 1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse). 0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns: Returns:
Sum of KL divergence losses for the chunk. Sum of KL divergence losses for the chunk.
""" """
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size] student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float() student_logits_temp_scaled.float()
) )
@@ -56,6 +61,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
student_logits_topk_temp_scaled - student_lse student_logits_topk_temp_scaled - student_lse
) )
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
@@ -67,7 +78,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# Student probabilities P_student from log P_student # Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp() student_probs_topk_valid = student_logprobs_topk_valid.exp()
kd_loss_per_token = torch.zeros_like(target_logprobs_valid) # kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student)) # KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher. # Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices). # student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta < 1.0: # Contribution from Forward KL if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * ( fwd_kl_per_token = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_topk_valid target_logprobs_valid - student_logprobs_topk_valid
) )
kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token kd_loss = fwd_kl_per_token.sum()
if beta > 0.0: # Contribution from Reverse KL elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * ( rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - target_logprobs_valid student_logprobs_topk_valid - target_logprobs_valid
) )
kd_loss_per_token += beta * rev_kl_per_token kd_loss = rev_kl_per_token.sum()
else:
kd_loss = kd_loss_per_token.sum() # JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss return kd_loss
@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss: bool = True, compute_ce_loss: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
beta: float = 0.0, beta: float = 0.0,
normalize_topk: bool = True,
): ):
# Compute student logits for the chunk from hidden states and LM head # Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim] # student_input_chunk: [chunk_size, hidden_dim]
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_logprobs_chunk, target_logprobs_chunk,
target_mask_chunk, target_mask_chunk,
beta=beta, beta=beta,
normalize_topk=normalize_topk,
) )
return soft_loss, ce_loss return soft_loss, ce_loss
@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compiled: bool = False, compiled: bool = False,
chunk_size: int = 1024, chunk_size: int = 1024,
compute_ce_loss: bool = True, compute_ce_loss: bool = True,
normalize_topk: bool = True,
): ):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight) grad_weight_acc = torch.zeros_like(student_lm_head_weight)
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
compute_ce_loss=compute_ce_loss, compute_ce_loss=compute_ce_loss,
temperature=temperature, temperature=temperature,
beta=beta, beta=beta,
normalize_topk=normalize_topk,
) )
def accumulate_chunk_grads( def accumulate_chunk_grads(
@@ -311,7 +341,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.bias_was_none = student_lm_head_bias is None ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K) ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2 # we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2) kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
compiled: bool = True, compiled: bool = True,
chunk_size: int = 1024, chunk_size: int = 1024,
compute_ce_loss: bool = True, compute_ce_loss: bool = True,
normalize_topk: bool = True,
): ):
super().__init__() super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0): if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled = compiled self.compiled = compiled
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0: if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print( print(
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.compiled, self.compiled,
self.chunk_size, self.chunk_size,
self.compute_ce_loss, self.compute_ce_loss,
self.normalize_topk,
) )

View File

@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
self.args.kd_temperature, self.args.kd_temperature,
self.args.kd_beta, self.args.kd_beta,
compute_ce_loss=bool(self.args.kd_ce_alpha), compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
) )
def _set_signature_columns_if_needed(self): def _set_signature_columns_if_needed(self):

View File

@@ -0,0 +1,39 @@
"""Helper KD utils"""
import torch
from torch import FloatTensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor