diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index c1e537097..f921f80bb 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -6,7 +6,125 @@
 import torch
 import triton
 import triton.language as tl
-from .logsumexp import logsumexp_kernel
+
+@triton.jit
+def fused_logsumexp_logprobs_kernel(
+    student_logits_ptr,  # Input logits in original dtype
+    student_logprobs_ptr,  # Output logprobs (float32)
+    token_ids_ptr,  # Token IDs for top-k
+    B,
+    S,
+    V,
+    K,  # batch size, seq len, vocab size, top-k
+    temperature,
+    stride_l_b,
+    stride_l_s,
+    stride_l_v,
+    stride_lp_b,
+    stride_lp_s,
+    stride_lp_k,
+    stride_t_b,
+    stride_t_s,
+    stride_t_k,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Fused kernel that computes logsumexp and logprobs for topk tokens.
+    All computations are done in float32 for numerical stability.
+    """
+    # Program ID
+    pid = tl.program_id(0)
+    batch_idx = pid // S
+    seq_idx = pid % S
+
+    # Bounds check
+    if batch_idx >= B or seq_idx >= S:
+        return
+
+    # Compute logsumexp over the vocabulary
+    max_val = -float("inf")
+
+    # Phase 1: Find max value across vocabulary
+    for v_offset in range(0, V, BLOCK_SIZE):
+        # Create block indices and mask
+        block_size = min(BLOCK_SIZE, V - v_offset)
+        block_idx = tl.arange(0, BLOCK_SIZE)
+        mask = block_idx < block_size
+
+        # Load logits block and convert to float32 in-place
+        ptrs = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + (v_offset + block_idx) * stride_l_v
+        )
+        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            block_logits = block_logits / temperature
+
+        # Update max value
+        block_max = tl.max(block_logits, axis=0)
+        max_val = tl.maximum(max_val, block_max)
+
+    # Phase 2: Compute sum of exp(logits - max_val)
+    sum_exp = 0.0
+
+    for v_offset in range(0, V, BLOCK_SIZE):
+        # Create block indices and mask
+        block_size = min(BLOCK_SIZE, V - v_offset)
+        block_idx = tl.arange(0, BLOCK_SIZE)
+        mask = block_idx < block_size
+
+        # Load logits block and convert to float32 in-place
+        ptrs = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + (v_offset + block_idx) * stride_l_v
+        )
+        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            block_logits = block_logits / temperature
+
+        # Compute exp(logits - max_val) and add to sum
+        block_exp = tl.exp(block_logits - max_val)
+        sum_exp += tl.sum(block_exp * mask, axis=0)
+
+    # Compute final logsumexp
+    logsumexp = max_val + tl.log(sum_exp)
+
+    # Phase 3: Compute and store logprobs for the top-k tokens
+    token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
+    logprobs_base = (
+        student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
+    )
+
+    for k in range(K):
+        # Load token ID for position k
+        token_id = tl.load(token_ids_base + k * stride_t_k)
+
+        # Load the corresponding logit and convert to float32
+        token_logit_ptr = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + token_id * stride_l_v
+        )
+        token_logit = tl.load(token_logit_ptr).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            token_logit = token_logit / temperature
+
+        # Compute logprob directly: logit - logsumexp
+        token_logprob = token_logit - logsumexp
+
+        # Store the result
+        tl.store(logprobs_base + k * stride_lp_k, token_logprob)


 @triton.jit
@@ -218,64 +336,61 @@ class TopKKLDivergence(torch.autograd.Function):
         """
         Forward pass for KL divergence loss between top-k logprobs.
         """
-        # Convert inputs to appropriate types
-        student_logits = student_logits.float()
+        # Only convert target_logprobs to float, leave student_logits as is
         target_logprobs = target_logprobs.float()

         # Get dimensions
         batch_size, _, vocab_size = student_logits.shape
-        _, teacher_seq_len, _ = target_token_ids.shape
+        _, teacher_seq_len, top_k = target_token_ids.shape

         # Slice student logits to match teacher sequence length
         student_logits_for_kd = student_logits[:, :teacher_seq_len, :]

         if top_k_before_softmax:
-            # 1. Apply temperature to student logits
+            # Apply temperature to student logits
             if kd_temperature != 1.0:
                 student_logits_for_kd = student_logits_for_kd / kd_temperature

-            # 2. Gather student logits for top-k tokens
+            # Gather student logits for top-k tokens
             student_logits_topk = torch.gather(
                 student_logits_for_kd, dim=-1, index=target_token_ids
             )

-            # 3. Compute softmax over gathered logits
+            # Compute softmax over gathered logits
             student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
             student_probs_topk = torch.exp(student_logprobs_topk)
         else:
-            # 1. Apply temperature to student logits
-            if kd_temperature != 1.0:
-                student_logits_for_kd = student_logits_for_kd / kd_temperature
-
-            # 2. Gather student logits for top-k tokens
-            student_logits_topk = torch.gather(
-                student_logits_for_kd, dim=-1, index=target_token_ids
-            )
-
-            # 3. Compute logsumexp over full vocabulary using Triton
-            student_lse = torch.empty(
-                (batch_size, teacher_seq_len),
+            # Allocate output tensor for logprobs directly (always in float32)
+            student_logprobs_topk = torch.empty(
+                (batch_size, teacher_seq_len, top_k),
                 dtype=torch.float32,
                 device=student_logits.device,
             )

+            # Launch fused kernel directly
             grid = (batch_size * teacher_seq_len,)
-            logsumexp_kernel[grid](
+            fused_logsumexp_logprobs_kernel[grid](
                 student_logits_for_kd.contiguous(),
-                student_lse,
+                student_logprobs_topk,
+                target_token_ids.contiguous(),
                 batch_size,
                 teacher_seq_len,
                 vocab_size,
+                top_k,
+                kd_temperature,
                 student_logits_for_kd.stride(0),
                 student_logits_for_kd.stride(1),
                 student_logits_for_kd.stride(2),
-                student_lse.stride(0),
-                student_lse.stride(1),
+                student_logprobs_topk.stride(0),
+                student_logprobs_topk.stride(1),
+                student_logprobs_topk.stride(2),
+                target_token_ids.stride(0),
+                target_token_ids.stride(1),
+                target_token_ids.stride(2),
                 min(1024, triton.next_power_of_2(vocab_size)),
             )

-            # 4. Convert to logprobs
-            student_logprobs_topk = student_logits_topk - student_lse.unsqueeze(-1)
+            # Calculate probs from logprobs
             student_probs_topk = torch.exp(student_logprobs_topk)

         # Save tensors for backward pass
@@ -321,7 +436,7 @@ class TopKKLDivergence(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         """
-        Optimized backward pass for KL divergence loss.
+        Optimized backward pass for KL divergence loss with proper dtype handling.
         """
         (
             student_logits,
@@ -333,11 +448,13 @@ class TopKKLDivergence(torch.autograd.Function):
         kd_temperature = ctx.kd_temperature
         num_items_in_batch = ctx.num_items_in_batch
+
+        # Store original dtype for later conversion
+        original_dtype = student_logits.dtype
         batch_size, seq_len, vocab_size = student_logits.shape
         _, _, top_k = target_token_ids.shape

-        # Initialize gradient tensor
-        grad_student_logits = torch.zeros_like(student_logits)
+        # Initialize gradient tensor in float32 to support atomic operations
+        grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)

         # Compute scaling factor
         scale = grad_output.item()
@@ -353,15 +470,13 @@ class TopKKLDivergence(torch.autograd.Function):
             scale = scale / float(target_mask.sum().item())

         # Apply chain rule for temperature scaling (1/temperature)
-        # This comes from d(logits/temperature)/d(logits) = 1/temperature
         if kd_temperature != 1.0:
             scale = scale / kd_temperature

         # Convert teacher logprobs to probabilities
         teacher_probs = torch.exp(target_logprobs)

-        # Depending on which mode was used in forward, we use different gradient calculation
-        # FIXME: top_k_before_softmax not correctly yet?
+        # Launch gradient computation kernel
         grid = (batch_size * seq_len,)
         grad_softmax_kernel[grid](
             grad_student_logits.contiguous(),
@@ -392,6 +507,10 @@ class TopKKLDivergence(torch.autograd.Function):
             min(1024, triton.next_power_of_2(top_k)),
         )

+        # Convert gradient back to original dtype if needed
+        if original_dtype != torch.float32:
+            grad_student_logits = grad_student_logits.to(original_dtype)
+
         # Return gradients for student_logits and None for other inputs
         return grad_student_logits, None, None, None, None, None, None