diff --git a/src/axolotl/integrations/kd/topk_logprob/bench_kl.py b/src/axolotl/integrations/kd/topk_logprob/bench_kl.py new file mode 100644 index 000000000..d67320bb7 --- /dev/null +++ b/src/axolotl/integrations/kd/topk_logprob/bench_kl.py @@ -0,0 +1,388 @@ +""" +benchmark utility helper for benchmarking the KL divergence triton kernel +""" +import gc +import time + +import torch +from torch.utils.benchmark import Timer + +from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss +from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss + + +# pylint: disable=cell-var-from-loop +def benchmark_kl_div_loss_with_backward(): + # Test configurations + batch_sizes = [1, 4] + seq_lens = [64, 512, 2048, 4096, 8192] + vocab_size = 32000 + top_k = 64 + + # Store results + results = [] + + # Run benchmarks + for batch_size in batch_sizes: + for seq_len in seq_lens: + # Generate random test data + torch.manual_seed(42) + + # Create tensors with gradients + student_logits = torch.randn( + batch_size, seq_len, vocab_size, device="cuda", requires_grad=True + ) + target_token_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len, top_k), device="cuda" + ) + target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda") + target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1) + target_mask = torch.randint( + 0, 2, (batch_size, seq_len, top_k), device="cuda" + ).float() + + # Clone student_logits for the two implementations + student_logits_ref = student_logits.clone().detach().requires_grad_(True) + student_logits_triton = student_logits.clone().detach().requires_grad_(True) + + # Define functions for timing that include both forward and backward passes + def run_reference(): + # Forward pass + loss_ref = eager_loss( + student_logits_ref, target_token_ids, target_logprobs, target_mask + ) + # Backward pass + loss_ref.backward() + + def run_triton(): + # Forward pass + loss_triton = triton_loss( + student_logits_triton, + target_token_ids, + target_logprobs, + target_mask, + ) + # Backward pass + loss_triton.backward() + + # Benchmark reference implementation (forward + backward) + t0 = Timer( + stmt="run_reference()", + globals={ + "run_reference": run_reference, + }, + ) + # Reset gradients before timing + student_logits_ref.grad = None + ref_time = t0.timeit(10).median * 1000 # Convert to ms + + # Benchmark Triton implementation (forward + backward) + t1 = Timer( + stmt="run_triton()", + globals={ + "run_triton": run_triton, + }, + ) + # Reset gradients before timing + student_logits_triton.grad = None + triton_time = t1.timeit(10).median * 1000 # Convert to ms + + # Compute speedup + speedup = ref_time / triton_time if triton_time > 0 else float("inf") + + # Store results + results.append( + { + "batch_size": batch_size, + "seq_len": seq_len, + "reference_time_ms": ref_time, + "triton_time_ms": triton_time, + "speedup": speedup, + } + ) + + print(f"Batch size: {batch_size}, Seq len: {seq_len}") + print(f" Reference time (fwd+bwd): {ref_time:.2f} ms") + print(f" Triton time (fwd+bwd): {triton_time:.2f} ms") + print(f" Speedup: {speedup:.2f}x") + + return results + + +def benchmark_forward_backward_separately(): + """ + Benchmark forward and backward passes separately to identify where the speedup comes from. + """ + # Test configurations + batch_sizes = [1, 4, 8] + seq_lens = [64, 512, 2048] + vocab_size = 32000 + top_k = 64 + + # Store results + detailed_results = [] + + # Run benchmarks + for batch_size in batch_sizes: + for seq_len in seq_lens: + # Generate random test data + torch.manual_seed(42) + + # Create tensors with gradients + student_logits = torch.randn( + batch_size, seq_len, vocab_size, device="cuda", requires_grad=True + ) + target_token_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len, top_k), device="cuda" + ) + target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda") + target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1) + target_mask = torch.randint( + 0, 2, (batch_size, seq_len, top_k), device="cuda" + ).float() + + # Clone student_logits for the two implementations + student_logits_ref = student_logits.clone().detach().requires_grad_(True) + student_logits_triton = student_logits.clone().detach().requires_grad_(True) + + # Forward-only reference + def run_reference_forward(): + with torch.no_grad(): + return eager_loss( + student_logits_ref, + target_token_ids, + target_logprobs, + target_mask, + ) + + # Forward-only triton + def run_triton_forward(): + with torch.no_grad(): + return triton_loss( + student_logits_triton, + target_token_ids, + target_logprobs, + target_mask, + ) + + # Benchmark forward pass only + + t0_fwd = Timer( + stmt="run_reference_forward()", + globals={ + "run_reference_forward": run_reference_forward, + }, + ) + ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms + + t1_fwd = Timer( + stmt="run_triton_forward()", + globals={ + "run_triton_forward": run_triton_forward, + }, + ) + triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms + + # Pre-compute losses for backward pass benchmarking + loss_ref = eager_loss( + student_logits_ref, target_token_ids, target_logprobs, target_mask + ) + loss_triton = triton_loss( + student_logits_triton, target_token_ids, target_logprobs, target_mask + ) + + # Backward-only reference + def run_reference_backward(): + student_logits_ref.grad = None + loss_ref.backward() + + # Backward-only triton + def run_triton_backward(): + student_logits_triton.grad = None + loss_triton.backward() + + # Benchmark backward pass only + t0_bwd = Timer( + stmt="run_reference_backward()", + globals={ + "run_reference_backward": run_reference_backward, + }, + ) + ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms + + t1_bwd = Timer( + stmt="run_triton_backward()", + globals={ + "run_triton_backward": run_triton_backward, + }, + ) + triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms + + # Compute speedups + fwd_speedup = ( + ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf") + ) + bwd_speedup = ( + ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf") + ) + total_ref_time = ref_fwd_time + ref_bwd_time + total_triton_time = triton_fwd_time + triton_bwd_time + total_speedup = ( + total_ref_time / total_triton_time + if total_triton_time > 0 + else float("inf") + ) + + # Store results + detailed_results.append( + { + "batch_size": batch_size, + "seq_len": seq_len, + "ref_forward_ms": ref_fwd_time, + "triton_forward_ms": triton_fwd_time, + "forward_speedup": fwd_speedup, + "ref_backward_ms": ref_bwd_time, + "triton_backward_ms": triton_bwd_time, + "backward_speedup": bwd_speedup, + "total_ref_ms": total_ref_time, + "total_triton_ms": total_triton_time, + "total_speedup": total_speedup, + } + ) + + print(f"Batch size: {batch_size}, Seq len: {seq_len}") + print( + f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x" + ) + print( + f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x" + ) + print( + f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x" + ) + + return detailed_results + + +def benchmark_memory_usage_with_backward(): + # Test configurations + batch_sizes = [1, 2] + seq_len = 8192 + vocab_size = 128000 + top_k = 64 + + # Store results + mem_results = [] + + # Run benchmarks + for batch_size in batch_sizes: + # Generate random test data + torch.manual_seed(42) + student_logits = torch.randn( + batch_size, seq_len, vocab_size, device="cuda", requires_grad=True + ) + target_token_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len, top_k), device="cuda" + ) + target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda") + target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1) + target_mask = torch.randint( + 0, 2, (batch_size, seq_len, top_k), device="cuda" + ).float() + + # Clone student_logits for the implementations + student_logits_ref = student_logits.clone().detach().requires_grad_(True) + student_logits_triton = student_logits.clone().detach().requires_grad_(True) + + # Measure PyTorch memory usage (forward + backward) + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + loss_ref = eager_loss( + student_logits_ref, target_token_ids, target_logprobs, target_mask + ) + loss_ref.backward() + torch.cuda.synchronize() + pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB + + # Measure Triton memory usage (forward + backward) + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + loss_triton = triton_loss( + student_logits_triton, target_token_ids, target_logprobs, target_mask + ) + loss_triton.backward() + torch.cuda.synchronize() + triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB + + # Measure Triton memory usage with different chunk sizes (forward + backward) + for n_chunks in [1, 2, 4, 8]: + student_logits_chunk = student_logits.clone().detach().requires_grad_(True) + + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + loss_chunk = triton_loss( + student_logits_chunk, + target_token_ids, + target_logprobs, + target_mask, + ) + loss_chunk.backward() + torch.cuda.synchronize() + chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB + + mem_results.append( + { + "batch_size": batch_size, + "implementation": f"Triton (chunks={n_chunks})", + "memory_mb": chunk_mem, + } + ) + + # Store results + mem_results.append( + { + "batch_size": batch_size, + "implementation": "PyTorch", + "memory_mb": pytorch_mem, + } + ) + + mem_results.append( + { + "batch_size": batch_size, + "implementation": "Triton (default)", + "memory_mb": triton_mem, + } + ) + + print(f"Batch size: {batch_size} (with backward pass)") + print(f" PyTorch memory: {pytorch_mem:.2f} MB") + print(f" Triton memory: {triton_mem:.2f} MB") + print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%") + + return mem_results + + +def main(): + print("Running benchmarks with forward and backward passes...") + benchmark_kl_div_loss_with_backward() + clean() + + print("\nRunning detailed forward/backward benchmarks...") + # benchmark_forward_backward_separately() + # clean() + + print("\nRunning memory usage benchmarks with backward passes...") + benchmark_memory_usage_with_backward() + clean() + + +def clean(): + for _ in range(5): + gc.collect() + torch.cuda.empty_cache() + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py index 525d77bee..c1e537097 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py @@ -51,7 +51,6 @@ def grad_softmax_kernel( grad_logits_base = ( grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s ) - # logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s token_ids_base = ( target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s ) @@ -63,26 +62,43 @@ def grad_softmax_kernel( ) mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s - # Softmax over full vocab case + # Process each teacher probability one at a time, computing all gradients for it for k in range(0, K): - # Load token ID, teacher prob, and mask for this position + # Load data for current position k teacher_prob = tl.load(teacher_probs_base + k * stride_p_k) student_prob_k = tl.load(student_probs_base + k * stride_sp_k) mask_val = tl.load(mask_base + k * stride_m_k) + + # Precompute the self-influence term (multiplied by scale) + self_term = teacher_prob * (1.0 - student_prob_k) * scale + + # Calculate gradient contributions for all positions j for j in range(0, K): - other_token_id = tl.load(token_ids_base + j * stride_t_k) + token_id_j = tl.load(token_ids_base + j * stride_t_k) student_prob_j = tl.load(student_probs_base + j * stride_sp_k) mask_j = tl.load(mask_base + j * stride_m_k) + + # Calculate the masking factor combined_mask = mask_val * mask_j - is_diagonal = tl.where(j == k, 1.0, 0.0) - self_grad = teacher_prob * (1.0 - student_prob_k) - cross_grad = -teacher_prob * student_prob_j - grad_val = ( - -(self_grad * is_diagonal + cross_grad * (1.0 - is_diagonal)) - * scale + + # Determine if this is a diagonal or off-diagonal term + is_k_equals_j = tl.where(k == j, 1.0, 0.0) + + # Compute the gradient contribution + # For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask + # For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask + grad_contribution = ( + -( + self_term * is_k_equals_j + - teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j) + ) * combined_mask ) - tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val) + + # Atomically update the gradient for this token + tl.atomic_add( + grad_logits_base + token_id_j * stride_gl_v, grad_contribution + ) @triton.jit @@ -315,7 +331,6 @@ class TopKKLDivergence(torch.autograd.Function): student_probs, ) = ctx.saved_tensors kd_temperature = ctx.kd_temperature - top_k_before_softmax = ctx.top_k_before_softmax num_items_in_batch = ctx.num_items_in_batch batch_size, seq_len, vocab_size = student_logits.shape @@ -346,72 +361,36 @@ class TopKKLDivergence(torch.autograd.Function): teacher_probs = torch.exp(target_logprobs) # Depending on which mode was used in forward, we use different gradient calculation - if top_k_before_softmax: - # Case 1: Softmax over top-k tokens - grid = (batch_size * seq_len,) - grad_topk_softmax_kernel[grid]( - grad_student_logits.contiguous(), - student_logits.contiguous(), - target_token_ids.contiguous(), - teacher_probs.contiguous(), - student_probs.contiguous(), - target_mask.contiguous(), - batch_size, - seq_len, - vocab_size, - top_k, - scale, - grad_student_logits.stride(0), - grad_student_logits.stride(1), - grad_student_logits.stride(2), - student_logits.stride(0), - student_logits.stride(1), - student_logits.stride(2), - target_token_ids.stride(0), - target_token_ids.stride(1), - target_token_ids.stride(2), - teacher_probs.stride(0), - teacher_probs.stride(1), - teacher_probs.stride(2), - student_probs.stride(0), - student_probs.stride(1), - student_probs.stride(2), - target_mask.stride(0), - target_mask.stride(1), - target_mask.stride(2), - min(1024, triton.next_power_of_2(top_k)), - ) - else: - # Case 2: Softmax over full vocab - grid = (batch_size * seq_len,) - grad_softmax_kernel[grid]( - grad_student_logits.contiguous(), - target_token_ids.contiguous(), - teacher_probs.contiguous(), - student_probs.contiguous(), - target_mask.contiguous(), - batch_size, - seq_len, - vocab_size, - top_k, - scale, - grad_student_logits.stride(0), - grad_student_logits.stride(1), - grad_student_logits.stride(2), - target_token_ids.stride(0), - target_token_ids.stride(1), - target_token_ids.stride(2), - teacher_probs.stride(0), - teacher_probs.stride(1), - teacher_probs.stride(2), - student_probs.stride(0), - student_probs.stride(1), - student_probs.stride(2), - target_mask.stride(0), - target_mask.stride(1), - target_mask.stride(2), - min(1024, triton.next_power_of_2(top_k)), - ) + # FIXME: top_k_before_softmax not correctly yet? + grid = (batch_size * seq_len,) + grad_softmax_kernel[grid]( + grad_student_logits.contiguous(), + target_token_ids.contiguous(), + teacher_probs.contiguous(), + student_probs.contiguous(), + target_mask.contiguous(), + batch_size, + seq_len, + vocab_size, + top_k, + scale, + grad_student_logits.stride(0), + grad_student_logits.stride(1), + grad_student_logits.stride(2), + target_token_ids.stride(0), + target_token_ids.stride(1), + target_token_ids.stride(2), + teacher_probs.stride(0), + teacher_probs.stride(1), + teacher_probs.stride(2), + student_probs.stride(0), + student_probs.stride(1), + student_probs.stride(2), + target_mask.stride(0), + target_mask.stride(1), + target_mask.stride(2), + min(1024, triton.next_power_of_2(top_k)), + ) # Return gradients for student_logits and None for other inputs return grad_student_logits, None, None, None, None, None, None diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index e57731461..2813a3686 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,6 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer +from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import topk_kd_loss_with_zscore from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton @@ -85,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer): num_items_in_batch=num_items_in_batch, ) else: - loss_kd = topk_kd_loss_triton( + loss_fn = ( + topk_kd_loss + if self.args.kd_top_k_before_softmax + else topk_kd_loss_triton + ) + loss_kd = loss_fn( shift_logits, target_token_ids_for_loss, target_logprobs_for_loss,