From 68e97d032ad379ee1580c1b562bfa957c3c940f9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 26 Feb 2025 04:44:24 -0500
Subject: [PATCH] chunk to prevent overflows in kernel

---
 .../kd/topk_logprob/forward_kl_triton.py | 359 ++++++++++++++----
 1 file changed, 279 insertions(+), 80 deletions(-)

diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index e79d799df..82b1b5671 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -2,6 +2,8 @@
 Optimized Triton kernel for KL divergence loss between teacher and student models.
 """
 # pylint: disable=invalid-name,unused-argument
+from typing import Optional, Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -316,12 +318,44 @@ def grad_topk_softmax_kernel(
         tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
 
 
+# Triton-accelerated implementation of KL divergence loss for top-k tokens
+# Chunking helper functions for handling long sequences
+def chunk_tensor(
+    tensor: torch.Tensor, max_seq_len: int
+) -> Tuple[torch.Tensor, Optional[int]]:
+    """Split a tensor along sequence dimension if needed."""
+    _, seq_len, *__ = tensor.shape
+
+    if seq_len <= max_seq_len:
+        return tensor, None
+
+    num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
+    chunks = []
+
+    for i in range(num_chunks):
+        start_idx = i * max_seq_len
+        end_idx = min((i + 1) * max_seq_len, seq_len)
+        chunks.append(tensor[:, start_idx:end_idx, ...])
+
+    return chunks, num_chunks
+
+
+def merge_chunks(chunks: list, original_shape: torch.Size):
+    """Merge chunks back into a single tensor with original shape."""
+    return torch.cat(chunks, dim=1)
+
+
 # Triton-accelerated implementation of KL divergence loss for top-k tokens
 class TopKKLDivergence(torch.autograd.Function):
     """
     Autograd function for KL divergence loss between top-k logprobs
+    with support for chunking to handle very long sequences.
     """
 
+    # Max sequence length to process in a single kernel launch
+    # This is a tunable parameter that might need adjustment based on GPU memory
+    MAX_SEQ_LEN = 8192
+
     @staticmethod
     def forward(
         ctx,
@@ -334,7 +368,7 @@ class TopKKLDivergence(torch.autograd.Function):
         top_k_before_softmax=0,
     ):
         """
-        Forward pass for KL divergence loss between top-k logprobs.
+        Forward pass for KL divergence loss between top-k logprobs with chunking.
""" # Only convert target_logprobs to float, leave student_logits as is target_logprobs = target_logprobs.float() @@ -346,52 +380,145 @@ class TopKKLDivergence(torch.autograd.Function): # Slice student logits to match teacher sequence length student_logits_for_kd = student_logits[:, :teacher_seq_len, :] - if top_k_before_softmax: - # Apply temperature to student logits - if kd_temperature != 1.0: - student_logits_for_kd = student_logits_for_kd / kd_temperature + # Store original values for backward pass + ctx.original_seq_len = teacher_seq_len + ctx.original_dtype = student_logits.dtype - # Gather student logits for top-k tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids + # Apply chunking for long sequences + if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN: + # Chunk the inputs + student_logits_chunks, num_chunks = chunk_tensor( + student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN ) + target_token_ids_chunks, _ = chunk_tensor( + target_token_ids, TopKKLDivergence.MAX_SEQ_LEN + ) + # target_logprobs_chunks, _ = chunk_tensor( + # target_logprobs, TopKKLDivergence.MAX_SEQ_LEN + # ) + # target_mask_chunks, _ = chunk_tensor( + # target_mask, TopKKLDivergence.MAX_SEQ_LEN + # ) + + # Process each chunk + student_logprobs_chunks = [] + student_probs_chunks = [] + + for i in range(num_chunks): + chunk_logits = student_logits_chunks[i] + chunk_token_ids = target_token_ids_chunks[i] + chunk_seq_len = chunk_logits.shape[1] + + if top_k_before_softmax: + # Apply temperature to student logits + if kd_temperature != 1.0: + chunk_logits = chunk_logits / kd_temperature + + # Gather student logits for top-k tokens + chunk_logits_topk = torch.gather( + chunk_logits, dim=-1, index=chunk_token_ids + ) + + # Compute softmax over gathered logits + chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1) + chunk_probs_topk = torch.exp(chunk_logprobs_topk) + else: + # Allocate output tensor for logprobs directly (always in float32) + chunk_logprobs_topk = torch.empty( + (batch_size, chunk_seq_len, top_k), + dtype=torch.float32, + device=chunk_logits.device, + ) + + # Launch fused kernel directly + grid = (batch_size * chunk_seq_len,) + fused_logsumexp_logprobs_kernel[grid]( + chunk_logits.contiguous(), + chunk_logprobs_topk, + chunk_token_ids.contiguous(), + batch_size, + chunk_seq_len, + vocab_size, + top_k, + kd_temperature, + chunk_logits.stride(0), + chunk_logits.stride(1), + chunk_logits.stride(2), + chunk_logprobs_topk.stride(0), + chunk_logprobs_topk.stride(1), + chunk_logprobs_topk.stride(2), + chunk_token_ids.stride(0), + chunk_token_ids.stride(1), + chunk_token_ids.stride(2), + min(1024, triton.next_power_of_2(vocab_size)), + ) + + # Calculate probs from logprobs + chunk_probs_topk = torch.exp(chunk_logprobs_topk) + + # Store results + student_logprobs_chunks.append(chunk_logprobs_topk) + student_probs_chunks.append(chunk_probs_topk) + + # Merge results + student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1) + student_probs_topk = torch.cat(student_probs_chunks, dim=1) + + # Save chunking info for backward pass + ctx.used_chunking = True + ctx.num_chunks = num_chunks - # Compute softmax over gathered logits - student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1) - student_probs_topk = torch.exp(student_logprobs_topk) else: - # Allocate output tensor for logprobs directly (always in float32) - student_logprobs_topk = torch.empty( - (batch_size, teacher_seq_len, top_k), - dtype=torch.float32, - 
-                device=student_logits.device,
-            )
+            # Original code path for shorter sequences
+            if top_k_before_softmax:
+                # Apply temperature to student logits
+                if kd_temperature != 1.0:
+                    student_logits_for_kd = student_logits_for_kd / kd_temperature
 
-            # Launch fused kernel directly
-            grid = (batch_size * teacher_seq_len,)
-            fused_logsumexp_logprobs_kernel[grid](
-                student_logits_for_kd.contiguous(),
-                student_logprobs_topk,
-                target_token_ids.contiguous(),
-                batch_size,
-                teacher_seq_len,
-                vocab_size,
-                top_k,
-                kd_temperature,
-                student_logits_for_kd.stride(0),
-                student_logits_for_kd.stride(1),
-                student_logits_for_kd.stride(2),
-                student_logprobs_topk.stride(0),
-                student_logprobs_topk.stride(1),
-                student_logprobs_topk.stride(2),
-                target_token_ids.stride(0),
-                target_token_ids.stride(1),
-                target_token_ids.stride(2),
-                min(1024, triton.next_power_of_2(vocab_size)),
-            )
+                # Gather student logits for top-k tokens
+                student_logits_topk = torch.gather(
+                    student_logits_for_kd, dim=-1, index=target_token_ids
+                )
 
-            # Calculate probs from logprobs
-            student_probs_topk = torch.exp(student_logprobs_topk)
+                # Compute softmax over gathered logits
+                student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
+                student_probs_topk = torch.exp(student_logprobs_topk)
+            else:
+                # Allocate output tensor for logprobs directly (always in float32)
+                student_logprobs_topk = torch.empty(
+                    (batch_size, teacher_seq_len, top_k),
+                    dtype=torch.float32,
+                    device=student_logits.device,
+                )
+
+                # Launch fused kernel directly
+                grid = (batch_size * teacher_seq_len,)
+                fused_logsumexp_logprobs_kernel[grid](
+                    student_logits_for_kd.contiguous(),
+                    student_logprobs_topk,
+                    target_token_ids.contiguous(),
+                    batch_size,
+                    teacher_seq_len,
+                    vocab_size,
+                    top_k,
+                    kd_temperature,
+                    student_logits_for_kd.stride(0),
+                    student_logits_for_kd.stride(1),
+                    student_logits_for_kd.stride(2),
+                    student_logprobs_topk.stride(0),
+                    student_logprobs_topk.stride(1),
+                    student_logprobs_topk.stride(2),
+                    target_token_ids.stride(0),
+                    target_token_ids.stride(1),
+                    target_token_ids.stride(2),
+                    min(1024, triton.next_power_of_2(vocab_size)),
+                )
+
+                # Calculate probs from logprobs
+                student_probs_topk = torch.exp(student_logprobs_topk)
+
+            # No chunking used
+            ctx.used_chunking = False
 
         # Save tensors for backward pass
         ctx.save_for_backward(
@@ -408,9 +535,20 @@ class TopKKLDivergence(torch.autograd.Function):
         # Convert mask to boolean
         valid_mask = target_mask.bool()
 
-        # Extract valid tokens only
-        student_logprobs_valid = student_logprobs_topk[valid_mask]
-        target_logprobs_valid = target_logprobs[valid_mask]
+        # Extract valid tokens only - this is where the error was happening
+        # Use cloned contiguous tensors and explicit indexing for safety
+        student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
+        target_logprobs_flat = target_logprobs.view(-1, top_k)
+        valid_mask_flat = valid_mask.view(-1, top_k)
+
+        # Gather valid indices explicitly to avoid illegal memory access
+        valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
+        student_logprobs_valid = torch.index_select(
+            student_logprobs_flat.view(-1), 0, valid_indices
+        )
+        target_logprobs_valid = torch.index_select(
+            target_logprobs_flat.view(-1), 0, valid_indices
+        )
 
         # Convert teacher logprobs to probabilities
         teacher_probs_valid = torch.exp(target_logprobs_valid)
@@ -430,14 +568,15 @@ class TopKKLDivergence(torch.autograd.Function):
         if num_items_in_batch > 0:
             kd_loss = kd_loss / float(num_items_in_batch)
         else:
-            kd_loss = kd_loss / float(token_losses.size(0))
+            num_valid_tokens = valid_indices.numel()
+            kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
 
         return kd_loss
 
     @staticmethod
     def backward(ctx, grad_output):
         """
-        Optimized backward pass for KL divergence loss with proper dtype handling.
+        Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
         """
         (
             student_logits,
@@ -448,11 +587,11 @@ class TopKKLDivergence(torch.autograd.Function):
         ) = ctx.saved_tensors
         kd_temperature = ctx.kd_temperature
         num_items_in_batch = ctx.num_items_in_batch
+        original_dtype = ctx.original_dtype
 
-        # Store original dtype for later conversion
-        original_dtype = student_logits.dtype
-        batch_size, seq_len, vocab_size = student_logits.shape
-        _, _, top_k = target_token_ids.shape
+        # Get dimensions
+        batch_size, _, vocab_size = student_logits.shape
+        _, teacher_seq_len, top_k = target_token_ids.shape
 
         # Initialize gradient tensor in float32 to support atomic operations
         grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
@@ -477,36 +616,89 @@ class TopKKLDivergence(torch.autograd.Function):
         # Convert teacher logprobs to probabilities
         teacher_probs = torch.exp(target_logprobs)
 
-        # Launch gradient computation kernel
-        grid = (batch_size * seq_len,)
-        grad_softmax_kernel[grid](
-            grad_student_logits.contiguous(),
-            target_token_ids.contiguous(),
-            teacher_probs.contiguous(),
-            student_probs.contiguous(),
-            target_mask.contiguous(),
-            batch_size,
-            seq_len,
-            vocab_size,
-            top_k,
-            scale,
-            grad_student_logits.stride(0),
-            grad_student_logits.stride(1),
-            grad_student_logits.stride(2),
-            target_token_ids.stride(0),
-            target_token_ids.stride(1),
-            target_token_ids.stride(2),
-            teacher_probs.stride(0),
-            teacher_probs.stride(1),
-            teacher_probs.stride(2),
-            student_probs.stride(0),
-            student_probs.stride(1),
-            student_probs.stride(2),
-            target_mask.stride(0),
-            target_mask.stride(1),
-            target_mask.stride(2),
-            min(1024, triton.next_power_of_2(top_k)),
-        )
+        # Use chunking for the backward pass if used in forward
+        if getattr(ctx, "used_chunking", False):
+            num_chunks = ctx.num_chunks
+            max_seq = TopKKLDivergence.MAX_SEQ_LEN
+
+            # Process each chunk
+            for i in range(num_chunks):
+                start_idx = i * max_seq
+                end_idx = min((i + 1) * max_seq, teacher_seq_len)
+                chunk_len = end_idx - start_idx
+
+                # Get chunk slices
+                # student_logits_chunk = student_logits[:, start_idx:end_idx, :]
+                target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
+                teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
+                student_probs_chunk = student_probs[:, start_idx:end_idx, :]
+                target_mask_chunk = target_mask[:, start_idx:end_idx, :]
+                grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
+
+                # Launch gradient computation kernel for this chunk (views keep atomic adds in-place)
+                grid = (batch_size * chunk_len,)
+                grad_softmax_kernel[grid](
+                    grad_student_logits_chunk,
+                    target_token_ids_chunk,
+                    teacher_probs_chunk,
+                    student_probs_chunk,
+                    target_mask_chunk,
+                    batch_size,
+                    chunk_len,
+                    vocab_size,
+                    top_k,
+                    scale,
+                    grad_student_logits_chunk.stride(0),
+                    grad_student_logits_chunk.stride(1),
+                    grad_student_logits_chunk.stride(2),
+                    target_token_ids_chunk.stride(0),
+                    target_token_ids_chunk.stride(1),
+                    target_token_ids_chunk.stride(2),
+                    teacher_probs_chunk.stride(0),
+                    teacher_probs_chunk.stride(1),
+                    teacher_probs_chunk.stride(2),
+                    student_probs_chunk.stride(0),
+                    student_probs_chunk.stride(1),
+                    student_probs_chunk.stride(2),
+                    target_mask_chunk.stride(0),
+                    target_mask_chunk.stride(1),
+                    target_mask_chunk.stride(2),
+                    min(1024, triton.next_power_of_2(top_k)),
+                )
+
+                # Update the gradient tensor (already in-place)
+        else:
+            # Original code path for shorter sequences
+            # Launch gradient computation kernel
+            grid = (batch_size * teacher_seq_len,)
+            grad_softmax_kernel[grid](
+                grad_student_logits.contiguous(),
+                target_token_ids.contiguous(),
+                teacher_probs.contiguous(),
+                student_probs.contiguous(),
+                target_mask.contiguous(),
+                batch_size,
+                teacher_seq_len,
+                vocab_size,
+                top_k,
+                scale,
+                grad_student_logits.stride(0),
+                grad_student_logits.stride(1),
+                grad_student_logits.stride(2),
+                target_token_ids.stride(0),
+                target_token_ids.stride(1),
+                target_token_ids.stride(2),
+                teacher_probs.stride(0),
+                teacher_probs.stride(1),
+                teacher_probs.stride(2),
+                student_probs.stride(0),
+                student_probs.stride(1),
+                student_probs.stride(2),
+                target_mask.stride(0),
+                target_mask.stride(1),
+                target_mask.stride(2),
+                min(1024, triton.next_power_of_2(top_k)),
+            )
 
         # Convert gradient back to original dtype if needed
         if original_dtype != torch.float32:
@@ -525,9 +717,11 @@ def loss(
     num_items_in_batch: int = -1,
     kd_temperature: float = 1.0,
     top_k_before_softmax: int = 0,
+    max_seq_len: Optional[int] = None,
 ):
     """
     Triton-accelerated Memory-efficient KL divergence loss computation for knowledge distillation
+    with support for very long sequences.
 
     Args:
         student_logits: Student logits [B, seq_len, vocab_size]
@@ -537,7 +731,12 @@ def loss(
         num_items_in_batch: Number of items for normalization (-1 for auto)
        kd_temperature: Temperature for KD
        top_k_before_softmax: Flag for softmax application order
+        max_seq_len: Override default MAX_SEQ_LEN value for chunking
     """
+    # Allow overriding the max sequence length
+    if max_seq_len is not None and max_seq_len > 0:
+        TopKKLDivergence.MAX_SEQ_LEN = max_seq_len
+
    total_loss = TopKKLDivergence.apply(
        student_logits,
        target_token_ids,
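
Usage sketch (not part of the patch): the snippet below is a hedged illustration of how the chunked loss entry point might be called after this change. The module path and the loss() keyword names are taken from the diff; the positional order of the first four tensors follows the Args docstring and the TopKKLDivergence.apply() call, and is assumed rather than confirmed. Tensor shapes, dtypes, the mask convention, and the max_seq_len value are invented for the example; max_seq_len simply overrides TopKKLDivergence.MAX_SEQ_LEN so sequences longer than the cap are processed chunk by chunk. It also assumes loss() returns the scalar total loss, as in the pre-patch implementation.

    # Hypothetical call site; all sizes and dtypes are illustrative only.
    # Shapes follow the docstring: student_logits [B, S, V], teacher tensors [B, S, K].
    import torch

    from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss

    B, S, V, K = 1, 20000, 32000, 64
    student_logits = torch.randn(B, S, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    target_token_ids = torch.randint(0, V, (B, S, K), device="cuda")
    target_logprobs = torch.log_softmax(torch.randn(B, S, K, device="cuda"), dim=-1)
    target_mask = torch.ones(B, S, K, device="cuda", dtype=torch.int32)  # 1 = position counts toward the loss (dtype assumed)

    kd_loss = loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        kd_temperature=1.0,
        max_seq_len=4096,  # chunk any sequence longer than 4096 positions
    )
    kd_loss.backward()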