From 1cc3a2d16c36d8a9453dc54e008133ffe8b4565d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 30 Dec 2024 21:34:46 -0500
Subject: [PATCH] make loss torch script compat

---
 .../kd/topk_logprob/forward_kl.py       | 58 +++++++++----------
 src/axolotl/integrations/kd/trainer.py  |  4 --
 2 files changed, 28 insertions(+), 34 deletions(-)

diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 4617d0d60..545223825 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -13,72 +13,70 @@
 """
 loss for top_k KL divergence
 """
-from typing import Optional
-
 import torch
 
 
+@torch.jit.script
 def loss(
-    student_logits,
-    target_token_ids,
-    target_logprobs,
-    target_mask,
-    num_items_in_batch: Optional[int] = None,
+    student_logits: torch.Tensor,
+    target_token_ids: torch.Tensor,
+    target_logprobs: torch.Tensor,
+    target_mask: torch.Tensor,
+    num_items_in_batch: int = -1,  # Use -1 to indicate "None"
     kd_temperature: float = 1.0,
-):
-    # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
+) -> torch.Tensor:
+    """
+    A KD loss function that is TorchScript-friendly.
+    """
     # Determine the teacher sequence length
-    # _, teacher_seq_len, top_k = target_token_ids.shape
+    # target_token_ids shape: [B, teacher_seq_len, K]
+    # student_logits shape: [B, student_seq_len, vocab_size]
    teacher_seq_len = target_token_ids.shape[1]
 
-    # Slice student logits to match the teacher-provided sequence length
+    # Slice student logits to match teacher-provided sequence length
     student_logits_for_kd = student_logits[
         :, :teacher_seq_len, :
     ]  # [B, teacher_seq_len, vocab_size]
 
     # Gather student logits for teacher's top-K tokens
-    # shape -> [B, teacher_seq_len, K]
     student_logits_topk = torch.gather(
         student_logits_for_kd, dim=-1, index=target_token_ids
-    )
+    )  # [B, teacher_seq_len, K]
 
-    # Apply KD temperature to student’s logits:
-    #   z_s(T) = z_s / T
+    # Apply KD temperature to student’s logits
     if kd_temperature != 1.0:
         student_logits_topk = student_logits_topk / kd_temperature
 
     # Convert student top-k logits to logprobs
     student_logprobs_topk = student_logits_topk - torch.logsumexp(
         student_logits_topk, dim=-1, keepdim=True
-    )  # [B, seq_len, K]
+    )  # [B, teacher_seq_len, K]
 
     # Convert teacher_mask to boolean for indexing
-    valid_mask = target_mask.bool()
+    # In TorchScript, .bool() is sometimes unsupported, so we do:
+    valid_mask = target_mask.to(torch.bool)
 
     # Prune tensors to only keep valid tokens
-    # This will result in 1D arrays of only valid positions
-    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
-    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]
+    student_logprobs_topk = student_logprobs_topk[valid_mask]
+    target_logprobs = target_logprobs[valid_mask]
 
-    # Since teacher_logprobs are already normalized, just exponentiate to get probabilities
+    # Convert teacher logprobs to probabilities
     teacher_probs = target_logprobs.exp()
 
-    # Compute forward KL:
-    #   KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
+    # Compute forward KL
     kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
     kd_loss = kd_loss_per_token.sum()
 
-    # 9) Multiply by T^2 (classical KD scaling)
+    # Multiply by T^2 (classical KD scaling)
     if kd_temperature != 1.0:
         kd_loss = kd_loss * (kd_temperature**2)
 
-    # Normalize by number of items or mean over valid tokens
-    if num_items_in_batch is not None:
-        # If you know how many items should be considered in the batch
-        kd_loss = kd_loss / num_items_in_batch
+    # Normalize by number of items (if provided) or by valid tokens
+    if num_items_in_batch > 0:
+        kd_loss = kd_loss / float(num_items_in_batch)
     else:
-        # Otherwise, just average over all valid tokens
-        kd_loss = kd_loss / kd_loss_per_token.size(0)
+        # Fall back to average over valid tokens
+        kd_loss = kd_loss / float(kd_loss_per_token.size(0))
 
     return kd_loss

diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index 9d686299e..7eda30659 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -16,8 +16,6 @@
 KD trainer
 """
 
-import torch
-
 from axolotl.core.trainers.base import AxolotlTrainer
 
 from .topk_logprob.forward_kl import loss as topk_kd_loss
@@ -106,6 +104,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
         if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
             loss *= self.accelerator.num_processes
 
-        torch.cuda.empty_cache()
-
         return (loss, outputs) if return_outputs else loss
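
A minimal smoke test for the scripted loss, for reference. The import path,
function name, and signature come from the patch above; the tensor sizes and
random inputs are illustrative assumptions. Because @torch.jit.script compiles
at definition time, simply importing the module already verifies that the
annotations are script-compatible.

import torch

from axolotl.integrations.kd.topk_logprob.forward_kl import loss

# Illustrative sizes (assumed, not from the patch): batch=2, teacher_seq_len=8,
# top_k=4, vocab=32. The student sequence may be longer; the loss slices it.
B, S, K, V = 2, 8, 4, 32
student_logits = torch.randn(B, S + 2, V)
target_token_ids = torch.randint(0, V, (B, S, K))
# Teacher logprobs, normalized over the top-K entries
target_logprobs = torch.log_softmax(torch.randn(B, S, K), dim=-1)
target_mask = torch.ones(B, S, K, dtype=torch.long)  # 1 = valid, 0 = padding

# num_items_in_batch=-1 exercises the sentinel path: the loss falls back to
# averaging over valid tokens instead of dividing by an explicit item count.
kd_loss = loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=-1,
    kd_temperature=2.0,
)
assert kd_loss.ndim == 0 and torch.isfinite(kd_loss)

The int sentinel replaces the old Optional[int] default so callers always pass
a plain int, which keeps the scripted signature simple; callers that previously
passed None would pass -1 (or omit the argument) instead.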