From d55a51623fc1bcddd63d6b1ccf41a1eb813ddb75 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 31 May 2025 08:26:29 -0400
Subject: [PATCH] more KD updates

---
 src/axolotl/integrations/kd/__init__.py      |  2 +
 src/axolotl/integrations/kd/args.py          |  6 ++
 .../kd/collator_online_teacher.py            | 97 ++++++-------------
 src/axolotl/integrations/kd/kernels/liger.py | 49 ++++++++--
 src/axolotl/integrations/kd/trainer.py       |  1 +
 src/axolotl/integrations/kd/utils.py         | 39 ++++++++
 6 files changed, 118 insertions(+), 76 deletions(-)
 create mode 100644 src/axolotl/integrations/kd/utils.py

diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index b53c62f8a..4c8535a0a 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
             "kd_alpha": cfg.kd_alpha,
             "kd_temperature": cfg.kd_temperature,
             "kd_beta": cfg.kd_beta,
+            "kd_normalize_topk": cfg.kd_normalize_topk,
         }

     def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
             "kd_temperature": cfg.kd_temperature,
             "kd_online_server": cfg.kd_online_server,
             "kd_online_timeout": cfg.kd_online_timeout,
+            "kd_normalize_topk": cfg.kd_normalize_topk,
         }

         if use_batch_sampler_collator:
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 8b6d6b6f5..96ac2300d 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
     kd_alpha: float | None = None  # loss coefficient for KD loss
     kd_temperature: float | None = None  # temperature for sampling during KD
     kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to re-normalize the top-k logprobs during KD
+    )

     # TODO online kd
     kd_online_server_base_url: str | None = None
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
     kd_alpha: float | None = None  # loss coefficient for KD loss
     kd_temperature: float | None = None  # temperature for sampling during KD
     kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to re-normalize the top-k logprobs during KD
+    )
diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py
index ecd0a0c01..584ace481 100644
--- a/src/axolotl/integrations/kd/collator_online_teacher.py
+++ b/src/axolotl/integrations/kd/collator_online_teacher.py
@@ -12,6 +12,7 @@ import torch
 from orjson import orjson

 from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
+from axolotl.integrations.kd.utils import normalize_logprobs
 from axolotl.utils.data.utils import retry_on_request_exceptions

 LOG = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         kd_online_server: Optional[str] = "vllm",
         kd_online_timeout: Optional[int] = 120,
         kd_cache_dir: Optional[str] = None,
+        kd_normalize_topk: Optional[bool] = True,
         **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         self.http_session = requests.Session()
         self.kd_online_timeout = kd_online_timeout
         self.kd_cache_dir = kd_cache_dir
+        self.kd_normalize_topk = kd_normalize_topk

     def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
         """
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         default_return = (
             [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
         )
-        # Ensure raw_logprobs matches kd_online_topk length for tensor operations
-        # This should ideally be handled by the caller ensuring correct padding/truncation first
-        if len(raw_logprobs) != self.kd_online_topk:
-            # This case should be rare if pre-padding/truncation is done correctly
-            LOG.warning(
-                f"Logprobs length mismatch in _normalize_logprobs. "
-                f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
-            )
-            padded_logprobs = raw_logprobs[: self.kd_online_topk]
-            if len(padded_logprobs) < self.kd_online_topk:
-                padded_logprobs.extend(
-                    [-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
-                )
-            raw_logprobs = padded_logprobs
-
-        try:
-            position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
-
-            # Convert logprobs at T_online to probabilities
-            # use log sum exp trick to avoid underflow
-            position_logprobs_lse = torch.logsumexp(
-                position_logprobs_tensor, dim=-1, keepdim=True
-            )
-            teacher_probs_t_online = torch.exp(
-                position_logprobs_tensor - position_logprobs_lse
-            )
-
-            # Normalize probabilities (sum to 1)
-            # This is important if the top-k from server aren't a full distribution
-            teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
-            if teacher_probs_t_online_sum.item() > 1e-9:
-                teacher_probs_t_online = (
-                    teacher_probs_t_online / teacher_probs_t_online_sum
-                )
-            else:
-                # If sum is zero, create uniform distribution to avoid NaN/Inf later
-                # This can happen if all raw_logprobs are -inf
-                if self.kd_online_topk > 0:
-                    teacher_probs_t_online = (
-                        torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
-                    )
-                # else: leave as is, will result in -inf logprobs
-            #
-            # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
-            #     dim=0, keepdim=True
-            # )
-            final_logprobs_tensor = torch.log(teacher_probs_t_online)
-
-            return final_logprobs_tensor.tolist()
-
-        except Exception as e:  # pylint: disable=broad-exception-caught
-            LOG.error(
-                f"Error during online logprob scaling: {e}. Returning raw logprobs.",
-                exc_info=True,
-            )
-            # Fallback to (padded/truncated) raw logprobs if scaling fails
-            return raw_logprobs
+        raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
+        return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()

     @retry_on_request_exceptions(max_retries=10, delay=5)
     def fetch_online_logprobs_sglang(
         self, batch_input_ids: List[List[int]], labels: List[List[int]]
     ):
         """
-        Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
+        Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
         Assumes API returns token IDs as strings in logprob dictionary keys.
         """
         api_endpoint = f"{self.kd_online_server_base_url}/generate"
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                             current_target_token_ids.append(
                                 pos_token_ids[: self.kd_online_topk]
                             )
-                            scaled_logprobs_for_position = self._normalize_logprobs(
-                                pos_logprobs_raw[: self.kd_online_topk]
-                            )
-                            current_target_logprobs.append(scaled_logprobs_for_position)
+
+                            if self.kd_normalize_topk:
+                                normalized_logprobs_for_position = self._normalize_logprobs(
+                                    pos_logprobs_raw[: self.kd_online_topk]
+                                )
+                                current_target_logprobs.append(
+                                    normalized_logprobs_for_position
+                                )
+                            else:
+                                current_target_logprobs.append(
+                                    pos_logprobs_raw[: self.kd_online_topk]
+                                )

                             # Mask depends on the corresponding label for the student
                             if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -442,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                            current_target_token_ids.append(
                                pos_token_ids[: self.kd_online_topk]
                            )
-                            # normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
-                            # current_target_logprobs.append(normalized_logprobs_for_position)
-                            # don't normalize for now as the probs seem to sum to 1.0 already
-                            current_target_logprobs.append(
-                                pos_logprobs_raw[: self.kd_online_topk]
-                            )
+                            if self.kd_normalize_topk:
+                                normalized_logprobs_for_position = self._normalize_logprobs(
+                                    pos_logprobs_raw[: self.kd_online_topk]
+                                )
+                                current_target_logprobs.append(
+                                    normalized_logprobs_for_position
+                                )
+                            else:
+                                current_target_logprobs.append(
+                                    pos_logprobs_raw[: self.kd_online_topk]
+                                )

                             # Mask depends on the corresponding label for the student
                             if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index 0050ffe33..2aea80578 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
     LigerFusedLinearDistillationBase,
 )

+from axolotl.integrations.kd.utils import normalize_logprobs
+

 class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
     """
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs
         target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]
         beta: float = 0.0,
+        normalize_topk: bool = True,
     ) -> torch.Tensor:
         """
         Compute Top-K KL divergence loss for a chunk.
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
                 0.0 for Forward KL (P_teacher || P_student).
                 1.0 for Reverse KL (P_student || P_teacher).
                 0.5 for Symmetric KL (average of Forward and Reverse).
+            normalize_topk: Whether to normalize the log probabilities
         Returns:
             Sum of KL divergence losses for the chunk.
         """
+        topk = target_token_ids_chunk.shape[-1]
         student_logits_temp_scaled = (  # [chunk_size, vocab_size]
             student_logits_temp_scaled.float()
         )
@@ -56,6 +61,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         student_logprobs_topk_temp_scaled = (
             student_logits_topk_temp_scaled - student_lse
         )

+        # we have the top-k student logprobs, normalize them
+        if normalize_topk:
+            student_logprobs_topk_temp_scaled = normalize_logprobs(
+                student_logprobs_topk_temp_scaled, topk
+            )
+
         valid_mask = target_mask_chunk.to(torch.bool)  # [chunk_size, top_k]
         student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
@@ -67,7 +78,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         # Student probabilities P_student from log P_student
         student_probs_topk_valid = student_logprobs_topk_valid.exp()

-        kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
+        # kd_loss_per_token = torch.zeros_like(target_logprobs_valid)

         # KL divergence: sum(P_teacher * (log P_teacher - log P_student))
         # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
         # Here, target_logprobs_valid are log_softmax_teacher.
         # student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
-        if beta < 1.0:  # Contribution from Forward KL
+        if beta == 0.0:  # Contribution from Forward KL
             fwd_kl_per_token = teacher_probs_valid * (
                 target_logprobs_valid - student_logprobs_topk_valid
             )
-            kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
-        if beta > 0.0:  # Contribution from Reverse KL
+            kd_loss = fwd_kl_per_token.sum()
+        elif beta == 1.0:  # Contribution from Reverse KL
             rev_kl_per_token = student_probs_topk_valid * (
                 student_logprobs_topk_valid - target_logprobs_valid
             )
-            kd_loss_per_token += beta * rev_kl_per_token
-
-        kd_loss = kd_loss_per_token.sum()
+            kd_loss = rev_kl_per_token.sum()
+        else:
+            # JSD - Jensen-Shannon Divergence / Symmetric
+            mean_probs = (
+                1 - beta
+            ) * student_probs_topk_valid + beta * teacher_probs_valid
+            log_mean_probs = mean_probs.log()
+            student_kl = F.kl_div(
+                log_mean_probs,
+                student_logprobs_topk_valid,
+                reduction="sum",
+                log_target=True,
+            )
+            teacher_kl = F.kl_div(
+                log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
+            )
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            kd_loss = jsd_loss

         return kd_loss

@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         compute_ce_loss: bool = True,
         temperature: float = 1.0,
         beta: float = 0.0,
+        normalize_topk: bool = True,
     ):
         # Compute student logits for the chunk from hidden states and LM head
         # student_input_chunk: [chunk_size, hidden_dim]
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             target_logprobs_chunk,
             target_mask_chunk,
             beta=beta,
+            normalize_topk=normalize_topk,
         )
         return soft_loss, ce_loss

@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         compiled: bool = False,
         chunk_size: int = 1024,
         compute_ce_loss: bool = True,
+        normalize_topk: bool = True,
     ):
         CHUNK_SIZE = chunk_size  # pylint: disable=invalid-name
         grad_weight_acc = torch.zeros_like(student_lm_head_weight)
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
             beta=beta,
+            normalize_topk=normalize_topk,
         )

         def accumulate_chunk_grads(
@@ -311,7 +341,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         ctx.bias_was_none = student_lm_head_bias is None
         ctx.orig_dims = (B, N, D, K)

-        # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
+        # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
         # we still need to scale the kd_loss by the temp^2
         kd_loss_acc = kd_loss_acc * (temperature**2)
         final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
         compiled: bool = True,
         chunk_size: int = 1024,
         compute_ce_loss: bool = True,
+        normalize_topk: bool = True,
     ):
         super().__init__()
         if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
         self.compiled = compiled
         self.chunk_size = chunk_size
         self.compute_ce_loss = compute_ce_loss
+        self.normalize_topk = normalize_topk

         if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
             print(
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
             self.compiled,
             self.chunk_size,
             self.compute_ce_loss,
+            self.normalize_topk,
         )
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index 131e1695d..91dae14c3 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
             self.args.kd_temperature,
             self.args.kd_beta,
             compute_ce_loss=bool(self.args.kd_ce_alpha),
+            normalize_topk=self.args.kd_normalize_topk,
         )

     def _set_signature_columns_if_needed(self):
diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py
new file mode 100644
index 000000000..7a3633596
--- /dev/null
+++ b/src/axolotl/integrations/kd/utils.py
@@ -0,0 +1,39 @@
+"""Helper KD utils"""
+
+import torch
+from torch import FloatTensor
+
+
+def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
+    """
+    Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
+    """
+    # Ensure logprobs matches topk length for tensor operations
+    # This should ideally be handled by the caller ensuring correct padding/truncation first
+    if logprobs.shape[-1] != topk:
+        # pad last dimension of logprobs to match topk length with -inf
+        padding_len = topk - logprobs.shape[-1]
+        padding_tensor = torch.full(
+            (
+                *logprobs.shape[:-1],
+                padding_len,
+            ),  # takes all dimensions of logprobs except the last, then appends padding_len
+            float("-inf"),
+            dtype=logprobs.dtype,
+            device=logprobs.device,
+        )
+        logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
+
+    # Convert logprobs at T_online to probabilities
+    # use log sum exp trick to avoid underflow
+    position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
+    teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
+
+    # Normalize probabilities (sum to 1)
+    # This is important if the top-k from server aren't a full distribution
+    teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
+    teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
+
+    final_logprobs_tensor = torch.log(teacher_probs_t_online)
+
+    return final_logprobs_tensor
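
Note on the new utils.normalize_logprobs helper: both the online-teacher collator and the Liger kernel now funnel their top-k logprobs through it. It re-normalizes the k entries so the corresponding probabilities sum to 1, which mathematically reduces to subtracting the logsumexp along the last dimension (the helper spells out the exp/divide/log steps, but the result is the same up to numerical noise). A minimal standalone sketch follows; the function name and the sample numbers are mine, not part of the patch:

import torch


def renormalize_topk_logprobs(logprobs: torch.Tensor, topk: int) -> torch.Tensor:
    """Re-normalize top-k logprobs so the probabilities along the last dim sum to 1."""
    if logprobs.shape[-1] != topk:
        # pad the last dim out to top-k with -inf (i.e. probability zero)
        pad = torch.full(
            (*logprobs.shape[:-1], topk - logprobs.shape[-1]),
            float("-inf"),
            dtype=logprobs.dtype,
            device=logprobs.device,
        )
        logprobs = torch.cat((logprobs, pad), dim=-1)
    # log p_i - logsumexp(log p) == log(p_i / sum_j p_j), computed stably
    return logprobs - torch.logsumexp(logprobs, dim=-1, keepdim=True)


# illustrative values only: top-3 teacher logprobs whose probabilities do not sum to 1
raw = torch.tensor([-0.5, -2.0, -3.0])
print(renormalize_topk_logprobs(raw, topk=3).exp().sum())  # ~1.0 after re-normalization

One caveat carried over from the removed inline collator code: if every entry in a row is -inf, the probability sum is zero and the division produces NaNs; the old implementation fell back to a uniform distribution in that case, while the new helper does not.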
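
On the loss side, the chunked Liger function no longer blends forward and reverse KL linearly: beta == 0.0 now gives pure forward KL, beta == 1.0 pure reverse KL, and anything in between a generalized Jensen-Shannon divergence against the beta-weighted mixture of the two distributions. Roughly, the per-chunk math over already-gathered top-k logprobs looks like the sketch below (assumptions: the function name and the explicit torch.nn.functional import are mine; the real kernel additionally fuses the LM-head matmul, top-k gathering, masking, and chunk accumulation):

import torch
import torch.nn.functional as F


def topk_kd_loss(
    student_logprobs: torch.Tensor,  # [n_tokens, top_k], temp-scaled student log-softmax
    teacher_logprobs: torch.Tensor,  # [n_tokens, top_k], temp-scaled teacher log-softmax
    mask: torch.Tensor,  # [n_tokens, top_k] bool, False for padded/ignored slots
    beta: float = 0.0,
) -> torch.Tensor:
    s_lp, t_lp = student_logprobs[mask], teacher_logprobs[mask]
    s_p, t_p = s_lp.exp(), t_lp.exp()

    if beta == 0.0:
        # forward KL: sum P_teacher * (log P_teacher - log P_student)
        return (t_p * (t_lp - s_lp)).sum()
    if beta == 1.0:
        # reverse KL: sum P_student * (log P_student - log P_teacher)
        return (s_p * (s_lp - t_lp)).sum()
    # generalized JSD against the mixture M = (1 - beta) * P_student + beta * P_teacher
    m_lp = ((1 - beta) * s_p + beta * t_p).log()
    student_kl = F.kl_div(m_lp, s_lp, reduction="sum", log_target=True)  # KL(student || M)
    teacher_kl = F.kl_div(m_lp, t_lp, reduction="sum", log_target=True)  # KL(teacher || M)
    return beta * teacher_kl + (1 - beta) * student_kl

At beta = 0.5 this is the standard Jensen-Shannon divergence, rather than the "average of Forward and Reverse" KL described in the unchanged docstring context.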
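
Finally, after the chunk loop the accumulated KL is combined with the hard CE loss as before: packed sequences form a single batch, so the batchmean reduction is just the accumulated sum, which is then scaled by temperature squared (the usual distillation convention so the soft-loss gradients keep a comparable magnitude as the temperature changes) and mixed with the CE term. A small sketch of that final combination; the helper name is mine:

import torch


def combine_kd_losses(
    kd_loss_sum: torch.Tensor,  # accumulated top-k KL over the packed batch
    ce_loss: torch.Tensor,  # hard cross-entropy loss
    temperature: float,
    weight_soft_loss: float,
    weight_hard_loss: float,
) -> torch.Tensor:
    # batchmean over a packed batch of size 1 is just the accumulated sum;
    # scale by T^2 so the soft loss stays comparable across temperatures
    kd_loss_sum = kd_loss_sum * (temperature**2)
    return weight_soft_loss * kd_loss_sum + weight_hard_loss * ce_loss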