From 9b1164b841943e131335a88da80f3936f07b635f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 21 Dec 2024 13:54:21 -0500
Subject: [PATCH] no log etc

---
 src/axolotl/integrations/kd/kernels/kd.py | 69 ++++++++++++-----------
 1 file changed, 35 insertions(+), 34 deletions(-)

diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py
index 36aacfef2..6e0a87735 100644
--- a/src/axolotl/integrations/kd/kernels/kd.py
+++ b/src/axolotl/integrations/kd/kernels/kd.py
@@ -2,7 +2,6 @@
 import torch
 import triton
 import triton.language as tl
-
 # --------------------------------------------------------
 # Triton Kernel for forward pass
 # --------------------------------------------------------
@@ -30,21 +29,21 @@
 
 @triton.jit
 def kd_forward_kernel(
-    student_logits_ptr, # float32[B, seq_len, K] after gather
-    teacher_logprobs_ptr, # float32[B, seq_len, K]
-    mask_ptr, # bool[B, seq_len, K] or int8
-    partial_kd_ptr, # float32[B, seq_len] (accumulator)
-    B, # total batch size
-    seq_len, # total sequence length from teacher
-    K, # top-K from teacher
-    BLOCK_SIZE: tl.constexpr # how many tokens per block in dimension0
+    student_logits_ptr,  # float32[B, seq_len, K] after gather
+    teacher_logprobs_ptr,  # float32[B, seq_len, K]
+    mask_ptr,  # bool[B, seq_len, K] or int8
+    partial_kd_ptr,  # float32[B, seq_len] (accumulator)
+    B,  # total batch size
+    seq_len,  # total sequence length from teacher
+    K,  # top-K from teacher
+    BLOCK_SIZE: tl.constexpr,  # how many tokens per block in dimension0
 ):
     # program_id is the global index for each block
     pid = tl.program_id(0)
 
     # Each block handles a range of seq positions in [0..B*seq_len)
     block_start = pid * BLOCK_SIZE
-    block_end = tl.min((pid+1)*BLOCK_SIZE, B * seq_len)
+    block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
     length = block_end - block_start
 
     # Offsets for indexing
@@ -69,7 +68,7 @@ def kd_forward_kernel(
 
     # For K top tokens, read the relevant student logits and teacher logprobs
    # We'll load them in a small loop:
-    logsumexp_val = float('-inf')
+    logsumexp_val = float("-inf")
     # We'll store them in a local array for a second pass
     student_logits_k = [0.0 for _ in range(K)]
     teacher_logprobs_k = [0.0 for _ in range(K)]
@@ -79,26 +78,17 @@ def kd_forward_kernel(
     for k in range(K):
         # load student logit
         student_val = tl.load(
-            student_logits_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )
         teacher_val = tl.load(
-            teacher_logprobs_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )
         # get mask
         mask_val = tl.load(
-            mask_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            mask_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )
 
         student_logits_k[k] = student_val
@@ -119,7 +109,7 @@ def kd_forward_kernel(
     # safe check
     if exp_sum == 0.0:
         exp_sum = 1e-8
-    logsumexp_val = logsumexp_val + float(torch.log(torch.tensor(exp_sum)))
+    logsumexp_val = logsumexp_val + tl.log(exp_sum)
 
     # compute partial kl
     # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
@@ -155,7 +145,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         BLOCK_SIZE = 1024
         # compute how many blocks we need
         total_positions = B * seq_len
-        grid = ( (total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE , )
+        grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
 
         partial_kd = torch.empty(
             grid[0], dtype=student_logits.dtype, device=student_logits.device
@@ -167,8 +157,10 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
             teacher_logprobs,
             mask,
             partial_kd,
-            B, seq_len, K,
-            BLOCK_SIZE=BLOCK_SIZE
+            B,
+            seq_len,
+            K,
+            BLOCK_SIZE=BLOCK_SIZE,
         )
 
         # Sum partials on CPU or GPU
@@ -186,7 +178,12 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         # Save context for backward
         # Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
         # But be mindful of memory usage. We'll demonstrate the minimal approach here:
-        ctx.save_for_backward(student_logits, teacher_logprobs, mask, torch.tensor(num_items_in_batch or 0))
+        ctx.save_for_backward(
+            student_logits,
+            teacher_logprobs,
+            mask,
+            torch.tensor(num_items_in_batch or 0),
+        )
         ctx.B = B
         ctx.seq_len = seq_len
         ctx.K = K
@@ -238,7 +235,9 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         # treat student_logits as if it requires grad
         stl = student_logits.clone().detach().requires_grad_(True)
         # compute logsumexp along K
-        logsumexp_val = torch.logsumexp(stl, dim=-1, keepdim=True)  # [B, seq_len, 1]
+        logsumexp_val = torch.logsumexp(
+            stl, dim=-1, keepdim=True
+        )  # [B, seq_len, 1]
         student_logprobs_topk = stl - logsumexp_val
 
         teacher_probs = teacher_logprobs.exp()  # p^S_k
@@ -266,7 +265,7 @@ def kd_loss_triton(
     student_logits,  # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
     teacher_logprobs,
     mask,
-    num_items_in_batch=None
+    num_items_in_batch=None,
 ):
     """
     Wrapper that calls our Triton-based forward+backward for KD.
@@ -274,4 +273,6 @@ def kd_loss_triton(
     or inside a separate kernel. This function expects that you've
     *already* called gather on student_logits -> shape [B, seq_len, K].
     """
-    return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch)
+    return _KLDivergenceTritonFn.apply(
+        student_logits, teacher_logprobs, mask, num_items_in_batch
+    )
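
Note for reviewers: below is a minimal sketch of how the patched kd_loss_triton
entry point might be exercised. The tensor sizes, the torch.gather pre-step, and
the dummy data are illustrative assumptions, not part of this change; the only
contract taken from the code itself is the docstring's requirement that the
top-K gather has already been applied, giving shapes [B, seq_len, K].

    import torch

    from axolotl.integrations.kd.kernels.kd import kd_loss_triton

    # Illustrative sizes (assumptions, not from the patch).
    B, seq_len, vocab_size, K = 2, 16, 32000, 64
    device = "cuda"  # Triton kernels require a GPU

    # Hypothetical teacher outputs: top-K token ids and their log-probs
    # per position (random here, just to give well-formed inputs).
    teacher_token_ids = torch.randint(0, vocab_size, (B, seq_len, K), device=device)
    teacher_logprobs = torch.log_softmax(
        torch.randn(B, seq_len, K, device=device), dim=-1
    )

    # The student forward pass yields full-vocab logits; gather the teacher's
    # top-K columns first, as the docstring requires.
    student_full_logits = torch.randn(
        B, seq_len, vocab_size, device=device, requires_grad=True
    )
    student_logits = torch.gather(
        student_full_logits, dim=-1, index=teacher_token_ids
    )

    # Valid-token mask over [B, seq_len, K]; everything valid in this toy case.
    mask = torch.ones(B, seq_len, K, dtype=torch.bool, device=device)

    loss = kd_loss_triton(
        student_logits, teacher_logprobs, mask, num_items_in_batch=B * seq_len
    )
    loss.backward()  # custom backward populates grads on the gathered logits

Keeping the gather outside the kernel, as the docstring suggests, fixes the
kernel's memory layout to a dense [B, seq_len, K] block regardless of vocabulary
size, so the Triton code never has to index into the full vocab dimension.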