make loss torch script compat

This commit is contained in:
Wing Lian
2024-12-30 21:34:46 -05:00
parent 287d2ca8d5
commit 1cc3a2d16c
2 changed files with 28 additions and 34 deletions

View File

@@ -13,72 +13,70 @@
""" """
loss for top_k KL divergence loss for top_k KL divergence
""" """
from typing import Optional
import torch import torch
@torch.jit.script
def loss( def loss(
student_logits, student_logits: torch.Tensor,
target_token_ids, target_token_ids: torch.Tensor,
target_logprobs, target_logprobs: torch.Tensor,
target_mask, target_mask: torch.Tensor,
num_items_in_batch: Optional[int] = None, num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0, kd_temperature: float = 1.0,
): ) -> torch.Tensor:
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding """
A KD loss function that is TorchScript-friendly.
"""
# Determine the teacher sequence length # Determine the teacher sequence length
# _, teacher_seq_len, top_k = target_token_ids.shape # target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1] teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match the teacher-provided sequence length # Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[ student_logits_for_kd = student_logits[
:, :teacher_seq_len, : :, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size] ] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens # Gather student logits for teacher's top-K tokens
# shape -> [B, teacher_seq_len, K]
student_logits_topk = torch.gather( student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids student_logits_for_kd, dim=-1, index=target_token_ids
) ) # [B, teacher_seq_len, K]
# Apply KD temperature to students logits: # Apply KD temperature to students logits
# z_s(T) = z_s / T
if kd_temperature != 1.0: if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs # Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp( student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True student_logits_topk, dim=-1, keepdim=True
) # [B, seq_len, K] ) # [B, teacher_seq_len, K]
# Convert teacher_mask to boolean for indexing # Convert teacher_mask to boolean for indexing
valid_mask = target_mask.bool() # In TorchScript, .bool() is sometimes unsupported, so we do:
valid_mask = target_mask.to(torch.bool)
# Prune tensors to only keep valid tokens # Prune tensors to only keep valid tokens
# This will result in 1D arrays of only valid positions student_logprobs_topk = student_logprobs_topk[valid_mask]
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens] target_logprobs = target_logprobs[valid_mask]
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities # Convert teacher logprobs to probabilities
teacher_probs = target_logprobs.exp() teacher_probs = target_logprobs.exp()
# Compute forward KL: # Compute forward KL
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum() kd_loss = kd_loss_per_token.sum()
# 9) Multiply by T^2 (classical KD scaling) # Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0: if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2) kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or mean over valid tokens # Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch is not None: if num_items_in_batch > 0:
# If you know how many items should be considered in the batch kd_loss = kd_loss / float(num_items_in_batch)
kd_loss = kd_loss / num_items_in_batch
else: else:
# Otherwise, just average over all valid tokens # Fall back to average over valid tokens
kd_loss = kd_loss / kd_loss_per_token.size(0) kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss return kd_loss

View File

@@ -16,8 +16,6 @@
KD trainer KD trainer
""" """
import torch
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
@@ -106,6 +104,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes loss *= self.accelerator.num_processes
torch.cuda.empty_cache()
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss