diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index b1ec058d4..525d77bee 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -6,73 +6,15 @@
 import torch
 import triton
 import triton.language as tl
-
-# Helper function for computing logsumexp
-@triton.jit
-def logsumexp_kernel(
-    logits_ptr,
-    output_ptr,
-    B,
-    S,
-    V,  # batch size, seq len, vocab size
-    stride_b,
-    stride_s,
-    stride_v,
-    out_stride_b,
-    out_stride_s,
-    BLOCK_SIZE: tl.constexpr,
-):
-    # Program ID
-    pid = tl.program_id(0)
-    batch_idx = pid // S
-    seq_idx = pid % S
-
-    # Bounds check
-    if batch_idx >= B or seq_idx >= S:
-        return
-
-    # Pointers
-    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
-
-    # Find maximum for numerical stability
-    max_val = -float("inf")
-    for v_offset in range(0, V, BLOCK_SIZE):
-        v_size = min(BLOCK_SIZE, V - v_offset)
-        mask = tl.arange(0, BLOCK_SIZE) < v_size
-
-        logits_block = tl.load(
-            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
-            mask=mask,
-            other=-float("inf"),
-        )
-        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
-
-    # Compute sum of exp(logit - max_val)
-    sum_exp = 0.0
-    for v_offset in range(0, V, BLOCK_SIZE):
-        v_size = min(BLOCK_SIZE, V - v_offset)
-        mask = tl.arange(0, BLOCK_SIZE) < v_size
-
-        logits_block = tl.load(
-            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
-            mask=mask,
-            other=-float("inf"),
-        )
-        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
-
-    # Compute logsumexp
-    result = max_val + tl.log(sum_exp)
-
-    # Store result
-    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
+from .logsumexp import logsumexp_kernel
 
 
 @triton.jit
 def grad_softmax_kernel(
     grad_student_logits_ptr,
-    student_logits_ptr,
     target_token_ids_ptr,
     teacher_probs_ptr,
+    student_probs_ptr,
     mask_ptr,
     B,
     S,
@@ -82,15 +24,15 @@
     stride_gl_b,
     stride_gl_s,
     stride_gl_v,
-    stride_l_b,
-    stride_l_s,
-    stride_l_v,
     stride_t_b,
     stride_t_s,
     stride_t_k,
     stride_p_b,
     stride_p_s,
     stride_p_k,
+    stride_sp_b,
+    stride_sp_s,
+    stride_sp_k,
     stride_m_b,
     stride_m_s,
     stride_m_k,
@@ -116,21 +58,31 @@
     teacher_probs_base = (
         teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
     )
+    student_probs_base = (
+        student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
+    )
     mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
 
     # Softmax over full vocab case
     for k in range(0, K):
         # Load token ID, teacher prob, and mask for this position
-        token_id = tl.load(token_ids_base + k * stride_t_k)
         teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
+        student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
         mask_val = tl.load(mask_base + k * stride_m_k)
-
-        # Apply mask by scaling gradient to zero if masked
-        grad_val = teacher_prob * scale * mask_val
-
-        # Update the gradient for this token's position in the vocabulary
-        # Only contributes if mask_val is non-zero
-        tl.atomic_add(grad_logits_base + token_id * stride_gl_v, grad_val)
+        for j in range(0, K):
+            other_token_id = tl.load(token_ids_base + j * stride_t_k)
+            student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
+            mask_j = tl.load(mask_base + j * stride_m_k)
+            combined_mask = mask_val * mask_j
+            is_diagonal = tl.where(j == k, 1.0, 0.0)
+            self_grad = teacher_prob * (1.0 - student_prob_k)
+            cross_grad = -teacher_prob * student_prob_j
+            grad_val = (
+                -(self_grad * is_diagonal + cross_grad * (1.0 - is_diagonal))
+                * scale
+                * combined_mask
+            )
+            tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
 
 
 @triton.jit
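Note on the hunk above: the rewritten `k`/`j` double loop accumulates the full softmax Jacobian over the top-K entries, i.e. `dL/dz_j = -sum_k t_k * (delta_kj - p_j)`, where `t` is the teacher distribution and `p` the student softmax. As a minimal eager sketch of the same accumulation for a single (batch, seq) position (function name and shapes are mine, not part of the patch; the kernel additionally scatters each `j` into its vocabulary slot via the target token IDs):

```python
# Reference sketch only, not part of the patch. t = teacher probs, p = student
# probs, m = mask, all of shape (K,); `scale` matches the kernel's scale factor.
import torch


def topk_kd_grad_reference(t, p, m, scale):
    K = t.shape[0]
    grad = torch.zeros(K)
    for k in range(K):
        for j in range(K):
            delta_kj = 1.0 if j == k else 0.0  # `is_diagonal` in the kernel
            # self_grad (j == k) and cross_grad (j != k) collapse to one formula
            grad[j] += -(t[k] * (delta_kj - p[j])) * scale * (m[k] * m[j])
    return grad
```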
@@ -338,7 +290,6 @@ class TopKKLDivergence(torch.autograd.Function):
         )
         kd_loss = token_losses.sum()
 
-        # pylint: disable=duplicate-code
         # Apply temperature scaling
         if kd_temperature != 1.0:
             kd_loss = kd_loss * (kd_temperature**2)
@@ -376,7 +327,7 @@ class TopKKLDivergence(torch.autograd.Function):
         # Compute scaling factor
         scale = grad_output.item()
 
-        # Apply temperature scaling
+        # Apply temperature scaling from forward pass
        if kd_temperature != 1.0:
             scale = scale * (kd_temperature**2)
 
@@ -386,7 +337,8 @@
         else:
             scale = scale / float(target_mask.sum().item())
 
-        # If we used temperature scaling in the forward pass, we need to apply it in the backward pass
+        # Apply chain rule for temperature scaling (1/temperature)
+        # This comes from d(logits/temperature)/d(logits) = 1/temperature
         if kd_temperature != 1.0:
             scale = scale / kd_temperature
 
@@ -434,9 +386,9 @@
         grid = (batch_size * seq_len,)
         grad_softmax_kernel[grid](
             grad_student_logits.contiguous(),
-            student_logits.contiguous(),
             target_token_ids.contiguous(),
             teacher_probs.contiguous(),
+            student_probs.contiguous(),
             target_mask.contiguous(),
             batch_size,
             seq_len,
@@ -446,15 +398,15 @@
             grad_student_logits.stride(0),
             grad_student_logits.stride(1),
             grad_student_logits.stride(2),
-            student_logits.stride(0),
-            student_logits.stride(1),
-            student_logits.stride(2),
             target_token_ids.stride(0),
             target_token_ids.stride(1),
             target_token_ids.stride(2),
             teacher_probs.stride(0),
             teacher_probs.stride(1),
             teacher_probs.stride(2),
+            student_probs.stride(0),
+            student_probs.stride(1),
+            student_probs.stride(2),
             target_mask.stride(0),
             target_mask.stride(1),
             target_mask.stride(2),
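The two backward-pass comments above encode the temperature bookkeeping: the forward loss carries a `T**2` factor, and the `z / T` inside the softmax contributes a further `1/T` through the chain rule, for a net factor of `T` on `d(loss)/d(logits)`. A quick standalone numerical check of that factor (a sketch, not part of the patch):

```python
# Verifies T**2 * (1/T) = T on the gradient of a temperature-scaled forward KL.
import torch

torch.manual_seed(0)
T = 2.0
z = torch.randn(5, requires_grad=True)  # student logits
t = torch.softmax(torch.randn(5), dim=-1)  # stand-in teacher distribution

# forward: T**2 * cross-entropy(teacher, softmax(z / T))
loss = -(T**2) * (t * torch.log_softmax(z / T, dim=-1)).sum()
loss.backward()

# manual gradient: T**2 * (1/T) * (softmax(z / T) - t) = T * (p - t)
p = torch.softmax(z.detach() / T, dim=-1)
assert torch.allclose(z.grad, T * (p - t), atol=1e-6)
```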
diff --git a/src/axolotl/integrations/kd/topk_logprob/logsumexp.py b/src/axolotl/integrations/kd/topk_logprob/logsumexp.py
new file mode 100644
index 000000000..113e50c3c
--- /dev/null
+++ b/src/axolotl/integrations/kd/topk_logprob/logsumexp.py
@@ -0,0 +1,67 @@
+"""
+Optimized Triton kernels for logsumexp
+"""
+# pylint: disable=invalid-name,unused-argument
+import triton
+import triton.language as tl
+
+
+# Helper function for computing logsumexp
+@triton.jit
+def logsumexp_kernel(
+    logits_ptr,
+    output_ptr,
+    B,
+    S,
+    V,  # batch size, seq len, vocab size
+    stride_b,
+    stride_s,
+    stride_v,
+    out_stride_b,
+    out_stride_s,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Program ID
+    # pylint: disable=duplicate-code
+    pid = tl.program_id(0)
+    batch_idx = pid // S
+    seq_idx = pid % S
+
+    # Bounds check
+    if batch_idx >= B or seq_idx >= S:
+        return
+
+    # Pointers
+    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
+
+    # Find maximum for numerical stability
+    max_val = -float("inf")
+    for v_offset in range(0, V, BLOCK_SIZE):
+        v_size = min(BLOCK_SIZE, V - v_offset)
+        mask = tl.arange(0, BLOCK_SIZE) < v_size
+
+        logits_block = tl.load(
+            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
+            mask=mask,
+            other=-float("inf"),
+        )
+        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
+
+    # Compute sum of exp(logit - max_val)
+    sum_exp = 0.0
+    for v_offset in range(0, V, BLOCK_SIZE):
+        v_size = min(BLOCK_SIZE, V - v_offset)
+        mask = tl.arange(0, BLOCK_SIZE) < v_size
+
+        logits_block = tl.load(
+            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
+            mask=mask,
+            other=-float("inf"),
+        )
+        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
+
+    # Compute logsumexp
+    result = max_val + tl.log(sum_exp)
+
+    # Store result
+    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
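The kernel factored out above is the standard two-pass streaming logsumexp: one pass over vocab blocks maintains a running max, a second pass accumulates `sum(exp(x - max))`, so no block ever exponentiates an unshifted logit. A plain PyTorch rendering of the same identity, `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))`, for reference (a sketch, not part of the patch):

```python
import torch


def blocked_logsumexp_reference(x: torch.Tensor, block: int = 1024) -> torch.Tensor:
    """Mirror the kernel's block loop over the last (vocab) dimension."""
    max_val = torch.full(x.shape[:-1], float("-inf"))
    for v in range(0, x.shape[-1], block):
        max_val = torch.maximum(max_val, x[..., v : v + block].max(dim=-1).values)
    sum_exp = torch.zeros_like(max_val)
    for v in range(0, x.shape[-1], block):
        sum_exp += torch.exp(x[..., v : v + block] - max_val.unsqueeze(-1)).sum(-1)
    return max_val + sum_exp.log()


x = torch.randn(2, 3, 5000)
assert torch.allclose(
    blocked_logsumexp_reference(x), torch.logsumexp(x, dim=-1), atol=1e-5
)
```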
diff --git a/tests/e2e/integrations/test_kl_loss.py b/tests/e2e/integrations/test_kl_loss.py
new file mode 100644
index 000000000..52a789166
--- /dev/null
+++ b/tests/e2e/integrations/test_kl_loss.py
@@ -0,0 +1,162 @@
+"""
+sanity checks on kl loss and gradients
+"""
+import torch
+
+# Import both implementations
+from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
+from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
+
+
+def test_kl_loss_gradient():
+    """Test that the gradient of the Triton implementation matches the eager implementation."""
+
+    # Set the random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Create random inputs
+    batch_size = 2
+    seq_len = 3
+    vocab_size = 100
+    top_k = 5
+
+    # Generate random student logits
+    student_logits = torch.randn(
+        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
+    )
+    student_logits_triton = student_logits.detach().clone().requires_grad_(True)
+
+    # Generate random target token IDs, ensuring they're valid indices
+    target_token_ids = torch.randint(
+        0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
+    )
+
+    # Generate random target logprobs (before normalization)
+    target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
+
+    # Normalize the target logprobs to ensure they form a valid distribution
+    target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
+
+    # Create a random mask with some tokens masked out
+    target_mask = torch.randint(
+        0, 2, (batch_size, seq_len, top_k), device="cuda"
+    ).float()
+
+    # Additional parameters
+    num_items_in_batch = batch_size * seq_len
+    kd_temperature = 1.0
+    top_k_before_softmax = 0  # Test both modes
+
+    # Compute the loss and gradients with eager implementation
+    loss_eager = eager_loss(
+        student_logits,
+        target_token_ids,
+        target_logprobs,
+        target_mask,
+        num_items_in_batch,
+        kd_temperature,
+        top_k_before_softmax,
+    )
+    loss_eager.backward()
+    grad_eager = student_logits.grad.clone()
+
+    # Reset gradients
+    student_logits.grad.zero_()
+
+    # Compute the loss and gradients with Triton implementation
+    loss_triton = triton_loss(
+        student_logits_triton,
+        target_token_ids,
+        target_logprobs,
+        target_mask,
+        num_items_in_batch,
+        kd_temperature,
+        top_k_before_softmax,
+    )
+    loss_triton.backward()
+    grad_triton = student_logits_triton.grad.clone()
+
+    # Compare loss values
+    print(f"Eager loss: {loss_eager.item()}")
+    print(f"Triton loss: {loss_triton.item()}")
+    loss_diff = abs(loss_eager.item() - loss_triton.item())
+    print(f"Loss difference: {loss_diff}")
+    assert loss_diff < 1e-5, "Loss values differ significantly!"
+
+    # Compare gradients
+    grad_diff = (grad_eager - grad_triton).abs().max().item()
+    print(f"Max gradient difference: {grad_diff}")
+
+    # Print some sample gradients
+    sample_idx = (0, 0, 0)  # (batch, seq, vocab)
+    print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
+    print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")
+
+    # Compute relative difference for non-zero gradients
+    mask = grad_eager.abs() > 1e-10
+    if mask.sum() > 0:
+        rel_diff = (
+            (
+                (grad_eager[mask] - grad_triton[mask]).abs()
+                / (grad_eager[mask].abs() + 1e-10)
+            )
+            .max()
+            .item()
+        )
+        print(f"Max relative gradient difference: {rel_diff}")
+        assert rel_diff < 1e-3, "Gradients differ significantly!"
+
+    # Also test top_k_before_softmax = 1 mode
+    top_k_before_softmax = 1
+
+    # Create fresh inputs for the second mode
+    student_logits = torch.randn(
+        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
+    )
+    student_logits_triton = student_logits.detach().clone().requires_grad_(True)
+
+    # Compute the loss and gradients with eager implementation
+    loss_eager = eager_loss(
+        student_logits,
+        target_token_ids,
+        target_logprobs,
+        target_mask,
+        num_items_in_batch,
+        kd_temperature,
+        top_k_before_softmax,
+    )
+    loss_eager.backward()
+    grad_eager = student_logits.grad.clone()
+
+    # Compute the loss and gradients with Triton implementation
+    loss_triton = triton_loss(
+        student_logits_triton,
+        target_token_ids,
+        target_logprobs,
+        target_mask,
+        num_items_in_batch,
+        kd_temperature,
+        top_k_before_softmax,
+    )
+    loss_triton.backward()
+    grad_triton = student_logits_triton.grad.clone()
+
+    # Compare gradients for top_k_before_softmax = 1
+    grad_diff = (grad_eager - grad_triton).abs().max().item()
+    print("\nWith top_k_before_softmax=1:")
+    print(f"Max gradient difference: {grad_diff}")
+
+    # Compute relative difference for non-zero gradients
+    mask = grad_eager.abs() > 1e-10
+    if mask.sum() > 0:
+        rel_diff = (
+            (
+                (grad_eager[mask] - grad_triton[mask]).abs()
+                / (grad_eager[mask].abs() + 1e-10)
+            )
+            .max()
+            .item()
+        )
+        assert (
+            rel_diff < 1e-3
+        ), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
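One design note on the tolerances in this test: the backward kernel accumulates with `tl.atomic_add`, and floating-point atomic accumulation order on the GPU is nondeterministic, so bitwise equality with the eager path cannot be expected; max-difference checks with tolerances are the right comparison. The same checks could equally be phrased with `torch.testing.assert_close`, which reports both absolute and relative error on failure, for example (illustrative values only):

```python
import torch

a = torch.tensor([1.0000001])  # e.g. a Triton result
b = torch.tensor([1.0])  # e.g. the eager reference
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-6)  # passes within tolerance
```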
diff --git a/tests/e2e/integrations/test_logsumexp.py b/tests/e2e/integrations/test_logsumexp.py
new file mode 100644
index 000000000..b508d2442
--- /dev/null
+++ b/tests/e2e/integrations/test_logsumexp.py
@@ -0,0 +1,204 @@
+"""
+sanity checks on logsumexp kernel validity
+"""
+import torch
+import triton
+
+from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
+
+
+# PyTorch implementation of logsumexp for reference
+def torch_logsumexp(logits):
+    """PyTorch implementation of logsumexp over last dimension"""
+    return torch.logsumexp(logits, dim=-1)
+
+
+# Wrapper function for Triton logsumexp kernel
+def triton_logsumexp(logits):
+    """Triton implementation of logsumexp over last dimension"""
+    B, S, V = logits.shape  # pylint: disable=invalid-name
+    output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
+
+    grid = (B * S,)
+    logsumexp_kernel[grid](
+        logits.contiguous(),
+        output,
+        B,
+        S,
+        V,
+        logits.stride(0),
+        logits.stride(1),
+        logits.stride(2),
+        output.stride(0),
+        output.stride(1),
+        min(1024, triton.next_power_of_2(V)),
+    )
+
+    return output
+
+
+class TritonLogSumExp(torch.autograd.Function):
+    """
+    Wrap a custom autograd function to use the Triton logsumexp for gradient testing
+    """
+
+    @staticmethod
+    def forward(ctx, logits):
+        B, S, V = logits.shape  # pylint: disable=invalid-name
+        output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
+
+        # Save inputs for backward pass
+        ctx.save_for_backward(logits)
+        ctx.shape = logits.shape
+
+        grid = (B * S,)
+        logsumexp_kernel[grid](
+            logits.contiguous(),
+            output,
+            B,
+            S,
+            V,
+            logits.stride(0),
+            logits.stride(1),
+            logits.stride(2),
+            output.stride(0),
+            output.stride(1),
+            min(1024, triton.next_power_of_2(V)),
+        )
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (logits,) = ctx.saved_tensors
+
+        # For logsumexp, the gradient is softmax(input) * grad_output
+        # First compute the logsumexp
+        lse = TritonLogSumExp.apply(logits)
+
+        # Compute softmax by exponentiating differences
+        softmax_output = torch.exp(logits - lse.unsqueeze(-1))
+
+        # Compute gradient of logsumexp by multiplying the softmax output by the gradient
+        grad_input = softmax_output * grad_output.unsqueeze(-1)
+
+        return grad_input
+
+
+def test_logsumexp_values():
+    """Test that the Triton logsumexp implementation matches PyTorch's"""
+    # Set random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Test with various input shapes
+    test_shapes = [
+        (2, 3, 10),  # small vocab
+        (4, 5, 100),  # medium vocab
+        (2, 2, 32000),  # large vocab (typical for LLMs)
+    ]
+
+    for shape in test_shapes:
+        # Create random input tensors
+        logits = torch.randn(shape, device="cuda", requires_grad=False)
+
+        # Compute logsumexp using both implementations
+        torch_result = torch_logsumexp(logits)
+        triton_result = triton_logsumexp(logits)
+
+        # Compare results
+        max_diff = (torch_result - triton_result).abs().max().item()
+        print(f"Shape {shape}, Max diff: {max_diff}")
+
+        # Assert that the results are very close
+        assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
+
+
+def test_logsumexp_edge_cases():
+    """Test edge cases for numerical stability"""
+    # Set random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Case 1: Very large values that might cause overflow
+    logits_large = torch.ones(2, 3, 100, device="cuda") * 1000
+
+    # Case 2: Very small values that might cause underflow
+    logits_small = torch.ones(2, 3, 100, device="cuda") * -1000
+
+    # Case 3: Mix of large and small values
+    logits_mixed = torch.zeros(2, 3, 100, device="cuda")
+    logits_mixed[:, :, 0] = 1000  # One very large value
+
+    # Case 4: All identical values
+    logits_identical = torch.ones(2, 3, 100, device="cuda") * 5
+
+    # Case 5: Extreme values with NaN check
+    logits_extreme = torch.cat(
+        [
+            torch.full((1, 3, 50), 1e10, device="cuda"),
+            torch.full((1, 3, 50), -1e10, device="cuda"),
+        ],
+        dim=0,
+    )
+
+    for i, logits in enumerate(
+        [logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
+    ):
+        # Compute logsumexp using both implementations
+        torch_result = torch_logsumexp(logits)
+        triton_result = triton_logsumexp(logits)
+
+        # Check for NaNs
+        assert not torch.isnan(
+            torch_result
+        ).any(), f"PyTorch produced NaNs for case {i+1}"
+        assert not torch.isnan(
+            triton_result
+        ).any(), f"Triton produced NaNs for case {i+1}"
+
+        # Compare results
+        max_diff = (torch_result - triton_result).abs().max().item()
+        print(f"Edge case {i+1}, Max diff: {max_diff}")
+
+        # For very extreme values, allow a bit more tolerance
+        if i == 4:  # extreme case
+            assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
+        else:
+            assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
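The gradient test below leans on the identity `d/dx logsumexp(x) = softmax(x)`, which is exactly what `TritonLogSumExp.backward` implements. A one-off autograd check of the identity itself (standalone, CPU-friendly sketch, not part of the patch):

```python
import torch

x = torch.randn(4, 7, requires_grad=True)
torch.logsumexp(x, dim=-1).sum().backward()
assert torch.allclose(x.grad, torch.softmax(x, dim=-1), atol=1e-6)
```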
+
+
+def test_logsumexp_gradients():
+    """Test that the gradients of Triton logsumexp match PyTorch's"""
+    # Set random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Create input tensors with gradients enabled
+    shapes = [(2, 3, 10), (4, 5, 100)]
+
+    for shape in shapes:
+        # Create two identical tensors for PyTorch and Triton
+        logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
+        logits_triton = logits_torch.clone().detach().requires_grad_(True)
+
+        # Forward pass
+        torch_output = torch_logsumexp(logits_torch)
+        triton_output = TritonLogSumExp.apply(logits_triton)
+
+        # Compare forward pass values
+        max_diff_forward = (torch_output - triton_output).abs().max().item()
+        assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"
+
+        # Create random gradient
+        grad_output = torch.randn_like(torch_output)
+
+        # Backward pass
+        torch_output.backward(grad_output)
+        triton_output.backward(grad_output)
+
+        # Compare gradients
+        max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
+        print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")
+
+        # Assert that gradients are very close
+        assert (
+            max_diff_grad < 1e-5
+        ), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
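Both test modules assume a CUDA device for the Triton kernels. If they can land on CPU-only workers, a module-level skip guard along these lines would keep them green (a suggestion, not part of the patch):

```python
import pytest
import torch

# Skip the whole module when no GPU is present; the Triton kernels require CUDA.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="requires CUDA for Triton kernels"
)
```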