From 75e1480c10cd5648decdb7c7415ca3e000e15597 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 24 Feb 2025 22:56:15 -0500
Subject: [PATCH] chunking not necessary

---
 .../kd/topk_logprob/forward_kl_triton.py | 93 +++----------------
 1 file changed, 13 insertions(+), 80 deletions(-)

diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index 5d7fef158..12bdda272 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -465,18 +465,17 @@ class TopKKLDivergence(torch.autograd.Function):
 
 
 # Wrapper function for chunked computation
-def kl_div_loss_chunked(
-    student_logits,
-    target_token_ids,
-    target_logprobs,
-    target_mask,
-    num_items_in_batch=-1,
-    kd_temperature=1.0,
-    top_k_before_softmax=0,
-    n_chunks=1,
+def loss(
+    student_logits: torch.Tensor,
+    target_token_ids: torch.Tensor,
+    target_logprobs: torch.Tensor,
+    target_mask: torch.Tensor,
+    num_items_in_batch: int = -1,
+    kd_temperature: float = 1.0,
+    top_k_before_softmax: int = 0,
 ):
     """
-    Memory-efficient KL divergence loss computation.
+    Triton-accelerated, memory-efficient KL divergence loss computation for knowledge distillation.
 
     Args:
         student_logits: Student logits [B, seq_len, vocab_size]
@@ -486,81 +485,15 @@ def kl_div_loss_chunked(
         num_items_in_batch: Number of items for normalization (-1 for auto)
         kd_temperature: Temperature for KD
         top_k_before_softmax: Flag for softmax application order
-        n_chunks: Number of chunks to process (for memory efficiency)
     """
-    batch_size = student_logits.shape[0]
-
-    # If n_chunks <= 0, use the entire batch size
-    if n_chunks <= 0:
-        n_chunks = batch_size
-
-    # Determine the actual number of chunks to use (find largest factor <= n_chunks)
-    factors = [i for i in range(1, batch_size + 1) if batch_size % i == 0]
-    actual_chunks = factors[
-        min(
-            len(factors) - 1,
-            max(
-                0,
-                next(
-                    (i for i, f in enumerate(factors) if f >= n_chunks),
-                    len(factors) - 1,
-                ),
-            ),
-        )
-    ]
-
-    # Compute chunk size
-    chunk_size = batch_size // actual_chunks
-    total_loss = 0.0
-
-    # Process in chunks
-    for i in range(0, batch_size, chunk_size):
-        chunk_end = min(i + chunk_size, batch_size)
-        chunk_loss = TopKKLDivergence.apply(
-            student_logits[i:chunk_end],
-            target_token_ids[i:chunk_end],
-            target_logprobs[i:chunk_end],
-            target_mask[i:chunk_end],
-            -1 if num_items_in_batch <= 0 else num_items_in_batch // actual_chunks,
-            kd_temperature,
-            top_k_before_softmax,
-        )
-        total_loss += chunk_loss
-
-    # Normalize by the number of chunks
-    return total_loss / actual_chunks
-
-
-def loss(
-    student_logits: torch.Tensor,
-    target_token_ids: torch.Tensor,
-    target_logprobs: torch.Tensor,
-    target_mask: torch.Tensor,
-    num_items_in_batch: int = -1,
-    kd_temperature: float = 1.0,
-    top_k_before_softmax: int = 0,
-    n_chunks: int = 1,
-) -> torch.Tensor:
-    """
-    Triton-accelerated KL divergence loss for knowledge distillation.
-
-    Args:
-        student_logits: Student model logits [B, seq_len, vocab_size]
-        target_token_ids: Teacher's top-k token IDs [B, seq_len, top_k]
-        target_logprobs: Teacher's top-k logprobs [B, seq_len, top_k]
-        target_mask: Mask for valid tokens [B, seq_len, top_k]
-        num_items_in_batch: Number of items for normalization (-1 for auto)
-        kd_temperature: Temperature for KD
-        top_k_before_softmax: Flag for softmax application order
-        n_chunks: Number of chunks for memory efficiency
-    """
-    return kl_div_loss_chunked(
+    total_loss = TopKKLDivergence.apply(
         student_logits,
         target_token_ids,
         target_logprobs,
         target_mask,
-        num_items_in_batch,
+        -1 if num_items_in_batch <= 0 else num_items_in_batch,
         kd_temperature,
         top_k_before_softmax,
-        n_chunks,
     )
+
+    return total_loss
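
For reference, a minimal usage sketch of the simplified `loss` wrapper after this patch. The shapes, dtypes, mask convention, and CUDA placement below are illustrative assumptions rather than values taken from the patch; the point is that the whole batch now goes through a single TopKKLDivergence.apply call, with no n_chunks argument.

    # Hypothetical example (not part of the patch): call the simplified wrapper
    # once for the full batch; the Triton kernel expects the tensors on a GPU.
    import torch

    from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss

    B, seq_len, vocab_size, top_k = 2, 128, 32000, 64  # assumed toy sizes

    student_logits = torch.randn(
        B, seq_len, vocab_size, device="cuda", requires_grad=True
    )
    target_token_ids = torch.randint(0, vocab_size, (B, seq_len, top_k), device="cuda")
    # Teacher top-k logprobs; a log-softmax over random values is only a stand-in here.
    target_logprobs = torch.log_softmax(
        torch.randn(B, seq_len, top_k, device="cuda"), dim=-1
    )
    # Assuming a mask of 1.0 for valid teacher entries and 0.0 for padding.
    target_mask = torch.ones(B, seq_len, top_k, device="cuda")

    kd_loss = loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch=-1,  # -1 -> automatic normalization, per the docstring
        kd_temperature=1.0,
    )
    kd_loss.backward()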