"""
KD trainer
"""

from typing import Optional

import torch

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.kd.kernels.kd import kd_loss_triton


def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch: Optional[int] = None,
    kd_temperature: float = 1.0,
):
    """
    Forward-KL distillation loss restricted to the teacher's top-K tokens.

    The student logits are gathered at the teacher-provided token ids,
    optionally divided by ``kd_temperature``, and normalized with a
    log-softmax over those K gathered entries only (not the full vocabulary).
    The loss is ``sum p_teacher * (log p_teacher - log p_student)`` over the
    entries selected by ``target_mask``.

    Args:
        student_logits: student logits; sliced along dim 1 to the length of
            ``target_token_ids`` and gathered along the last dim.
            Assumed ``[B, seq_len, vocab_size]`` -- TODO confirm with caller.
        target_token_ids: indices gathered from the last dim of
            ``student_logits``; assumed ``[B, teacher_seq_len, K]``.
        target_logprobs: teacher log-probabilities for those ids, same shape
            as ``target_token_ids``; exponentiated as-is, so they are assumed
            to be already normalized.
        target_mask: truthy where an entry is valid, same shape as
            ``target_token_ids``.
        num_items_in_batch: if given, the summed loss is divided by this
            value; otherwise by the number of mask-selected entries.
        kd_temperature: softening temperature applied to the student logits;
            the loss is scaled by ``kd_temperature ** 2`` when it is not 1.0.

    Returns:
        Scalar tensor holding the normalized KD loss.
    """
    # Determine the teacher sequence length and trim the student to match it.
    teacher_seq_len = target_token_ids.shape[1]
    student_logits_for_kd = student_logits[:, :teacher_seq_len, :]

    # Gather student logits for the teacher's top-K tokens.
    student_logits_topk = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )

    # Apply KD temperature to the student's logits: z_s(T) = z_s / T
    if kd_temperature != 1.0:
        student_logits_topk = student_logits_topk / kd_temperature

    # Convert student top-K logits to log-probs (normalized over the K entries).
    student_logprobs_topk = student_logits_topk - torch.logsumexp(
        student_logits_topk, dim=-1, keepdim=True
    )

    # Boolean-mask indexing flattens both tensors to the selected entries only.
    valid_mask = target_mask.bool()
    student_logprobs_topk = student_logprobs_topk[valid_mask]
    target_logprobs = target_logprobs[valid_mask]

    # Teacher log-probs are exponentiated directly to obtain probabilities.
    teacher_probs = target_logprobs.exp()

    # Forward KL: sum_k p^T_k (log p^T_k - log p^S_k) over all selected entries.
    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
    kd_loss = kd_loss_per_token.sum()

    # Multiply by T^2 (classical KD scaling).
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    if num_items_in_batch is not None:
        # Caller-supplied normalizer.
        kd_loss = kd_loss / num_items_in_batch
    else:
        # Average over the mask-selected entries. Clamp the divisor so a batch
        # with nothing selected yields 0 instead of 0 / 0 = NaN.
        kd_loss = kd_loss / max(kd_loss_per_token.size(0), 1)

    return kd_loss


class AxolotlKDTrainer(AxolotlTrainer):
    """
    Custom trainer subclass for Knowledge Distillation (KD)
    """

    def _set_signature_columns_if_needed(self):
        """
        Extend the parent's signature columns with the KD target columns.

        Adds ``target_logprobs``, ``target_token_ids`` and ``target_mask``
        (in that order) when they are missing. Nothing is added when the
        parent left ``_signature_columns`` empty or unset.
        """
        super()._set_signature_columns_if_needed()
        if self._signature_columns:
            columns_to_add = [
                column
                for column in ("target_logprobs", "target_token_ids", "target_mask")
                if column not in self._signature_columns
            ]
            if columns_to_add:
                self._signature_columns += columns_to_add

    def compute_loss_w_triton(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Compute the KD loss with the Triton kernel ``kd_loss_triton``.

        The KD target entries are popped from ``inputs`` (mutating it) before
        the model forward pass. Student logits and teacher targets are shifted
        by one position, the same alignment ``compute_loss`` uses.

        Args:
            model: the student model; called as ``model(**inputs)``.
            inputs: batch dict containing ``target_logprobs``,
                ``target_token_ids`` and ``target_mask`` plus model inputs.
            return_outputs: when true, return ``(loss, outputs)``.
            num_items_in_batch: optional normalizer for the summed KD loss;
                also forwarded to the model when it accepts loss kwargs.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")
        target_mask = inputs.pop("target_mask")

        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}
        outputs = model(**inputs)

        student_logits = outputs["logits"]

        # Trim student logits to the teacher-provided sequence length.
        teacher_seq_len = target_token_ids.shape[1]
        student_logits_for_kd = student_logits[:, :teacher_seq_len, :]

        # Logits at position t are scored against the teacher targets at
        # position t + 1, matching the alignment used in compute_loss.
        shift_logits = student_logits_for_kd[..., :-1, :].contiguous()
        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
        shift_target_mask = target_mask[..., 1:, :].contiguous()

        # Gather the student's logits at the teacher's top-K token ids.
        student_logits_topk = torch.gather(
            shift_logits,
            dim=-1,
            index=shift_target_token_ids,
        )

        # NOTE(review): self.args.kd_temperature is not applied on this path,
        # unlike compute_loss -- confirm whether kd_loss_triton should take it.
        kd_sum = kd_loss_triton(
            student_logits_topk,
            shift_target_logprobs,
            shift_target_mask,
        )

        if num_items_in_batch is not None:
            loss_kd = kd_sum / num_items_in_batch
        else:
            # Average over mask-selected entries; epsilon avoids dividing by 0.
            total_valid = shift_target_mask.sum()
            loss_kd = kd_sum / (total_valid + 1e-8)

        # Optionally combine with the model's own (cross-entropy) loss.
        if self.args.kd_ce_alpha > 0:
            kd_alpha = self.args.kd_alpha
            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
        else:
            loss = loss_kd

        return (loss, outputs) if return_outputs else loss

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        How the loss is computed by Trainer. By default, all models return the loss in
        the first element.

        Subclass and override for custom behavior.

        Here the loss is the KD loss from ``kd_loss_function``; when
        ``args.kd_ce_alpha > 0`` it is mixed with the model's own loss as
        ``kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd``. The KD target
        entries are popped from ``inputs`` (mutating it) before the forward
        pass.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # NOTE: compute_loss_w_triton(model, inputs, return_outputs,
        # num_items_in_batch) is the alternative Triton-based implementation.

        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")
        target_mask = inputs.pop("target_mask")

        seq_len = target_token_ids.shape[1]

        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}
        outputs = model(**inputs)

        # FIXME: account for tokenizer.padding_side
        student_logits = outputs["logits"][:, :seq_len, :].contiguous()

        # Logits at position t are scored against teacher targets at t + 1.
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
        shift_target_mask = target_mask[..., 1:, :].contiguous()

        loss_kd = kd_loss_function(
            shift_logits,
            shift_target_token_ids,
            shift_target_logprobs,
            shift_target_mask,
            num_items_in_batch=num_items_in_batch,
            kd_temperature=self.args.kd_temperature,
        )

        if self.args.kd_ce_alpha > 0:
            kd_alpha = self.args.kd_alpha
            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
        else:
            loss = loss_kd

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
                self.args.past_index
            ]

        # NOTE(review): presumably compensates for cross-process gradient
        # averaging when num_items_in_batch is a global count -- confirm
        # against the parent Trainer.compute_loss.
        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss *= self.accelerator.num_processes

        # Release cached CUDA allocator memory after every loss computation.
        torch.cuda.empty_cache()

        return (loss, outputs) if return_outputs else loss