diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index 3b15c8484..f90e7d02e 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -108,13 +108,21 @@ class AxolotlKDTrainer(AxolotlTrainer):
         )

         # Now call the Triton-based KD loss
-        loss_kd = kd_loss_triton(
+        kd_sum = kd_loss_triton(
             student_logits_topk,
             target_logprobs,  # teacher logprobs [B, seq_len, K]
             target_mask,  # mask [B, seq_len, K]
-            num_items_in_batch=num_items_in_batch,
         )
+        # Normalize however you want
+        if num_items_in_batch is not None:
+            loss_kd = kd_sum / num_items_in_batch
+        else:
+            # or do e.g. average over valid tokens
+            # quick example:
+            total_valid = target_mask.sum()
+            loss_kd = kd_sum / (total_valid + 1e-8)
+
         # optionally combine with CE loss
         if self.args.kd_ce_alpha > 0:
             loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd

diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py
index 2cd19aa64..f45717a01 100644
--- a/src/axolotl/integrations/kd/kernels/kd.py
+++ b/src/axolotl/integrations/kd/kernels/kd.py
@@ -29,236 +29,215 @@ import triton.language as tl


 @triton.jit
 def kd_forward_kernel(
-    student_logits_ptr,  # float32[B, seq_len, K] after gather
-    teacher_logprobs_ptr,  # float32[B, seq_len, K]
-    mask_ptr,  # bool[B, seq_len, K] or int8
-    partial_kd_ptr,  # float32[B, seq_len] (accumulator)
-    B,  # total batch size
-    seq_len,  # total sequence length from teacher
-    K,  # top-K from teacher
-    BLOCK_SIZE: tl.constexpr,  # how many tokens per block in dimension0
+    # student_logits after gather: [B, seq_len, K] flattened to 1D in row-major
+    student_logits_ptr: tl.tensor,
+    # teacher_logprobs: [B, seq_len, K] flattened
+    teacher_logprobs_ptr: tl.tensor,
+    # mask: [B, seq_len, K] flattened (bool or 0/1)
+    mask_ptr: tl.tensor,
+    # partial_kd: [B*seq_len] flattened buffer to store partial sums
+    partial_kd_ptr: tl.tensor,
+    B: tl.int32,
+    seq_len: tl.int32,
+    K: tl.int32,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    # program_id is the global index for each block
+    """
+    For each position in [0..B*seq_len), we:
+      - gather the K student logits
+      - compute logsumexp
+      - compute the KL sum = sum_{k} t_prob_k * ( t_log_k - s_logprob_k )
+      - store that partial sum into partial_kd_ptr[offset].
+    """
+    # 1) Identify which [B*seq_len] index this block handles
     pid = tl.program_id(0)
-    # Each block handles a range of seq positions in [0..B*seq_len)
-    block_start = pid * BLOCK_SIZE
-    block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
-    length = block_end - block_start
+    # 2) Vector of [0..BLOCK_SIZE) local offsets
+    offsets = tl.arange(0, BLOCK_SIZE)
+    # 3) Global indices = pid * BLOCK_SIZE + offsets
+    idx = pid * BLOCK_SIZE + offsets

-    # Offsets for indexing
-    # We want to interpret a linear index in [0..B*seq_len) as (batch_idx, seq_idx)
-    # E.g.:
-    #   batch_idx = block_start // seq_len
-    #   seq_idx = block_start % seq_len
-    # but we must do this for each element in the block. We'll do that inside a loop.
+    # 4) Mask to ensure we don’t read out-of-bounds
+    total_positions = B * seq_len
+    mask_pos = idx < total_positions

-    # We'll store a running partial KL sum in registers
-    # We do a for-loop for each position in the block, then do a thread-level reduction
-    kd_reg = 0.0
+    # 5) Convert a 1D `idx` => (b_idx, s_idx)
+    #    b_idx is the batch number, s_idx is the sequence position
+    b_idx = idx // seq_len
+    s_idx = idx % seq_len

-    # We'll iterate over each item in [block_start, block_end).
-    # A more advanced approach can use vectorization / warp-based parallelism inside the block.
-    for offset in range(length):
-        # Convert offset -> actual index in [0..B*seq_len)
-        linear_idx = block_start + offset
-        # batch index and sequence index
-        b_idx = linear_idx // seq_len
-        s_idx = linear_idx % seq_len
+    # We'll accumulate the KL for each index in a register array
+    kl_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

-        # For K top tokens, read the relevant student logits and teacher logprobs
-        # We'll load them in a small loop:
-        logsumexp_val = float("-inf")
-        # We'll store them in a local array for a second pass
-        student_logits_k = [0.0 for _ in range(K)]
-        teacher_logprobs_k = [0.0 for _ in range(K)]
-        valid_k = [0 for _ in range(K)]
+    # -------------------------------------------------------------------------
+    # First pass: find max logits over K to implement logsumexp
+    # -------------------------------------------------------------------------
+    max_val = tl.full([BLOCK_SIZE], -1e30, dtype=tl.float32)

-        # gather the top-K logits & teacher logprobs
-        for k in range(K):
-            # load student logit
-            student_val = tl.load(
-                student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
-            teacher_val = tl.load(
-                teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
-            # get mask
-            mask_val = tl.load(
-                mask_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
+    # Python-level loops are allowed in Triton as long as the
+    # operations inside are Triton ops, not torch or Python math.
+    for k in range(K):
+        # pointer offset in the flattened [B, seq_len, K] = b_idx*(seq_len*K) + s_idx*K + k
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k

-            student_logits_k[k] = student_val
-            teacher_logprobs_k[k] = teacher_val
-            valid_k[k] = mask_val
+        # load student logits; out-of-bounds lanes read a large negative value
+        # so they don't affect the max
+        student_val = tl.load(
+            student_logits_ptr + offset_k,
+            mask=mask_pos,
+            other=-1e30,
+        )
+        # update running max
+        max_val = tl.where(student_val > max_val, student_val, max_val)

-            # track max for logsumexp (naive approach)
-            if student_val > logsumexp_val:
-                logsumexp_val = student_val
+    # -------------------------------------------------------------------------
+    # Second pass: sum of exp(...) to complete logsumexp
+    # -------------------------------------------------------------------------
+    exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for k in range(K):
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k
+        student_val = tl.load(
+            student_logits_ptr + offset_k,
+            mask=mask_pos,
+            other=-1e30,
+        )
+        # exponent
+        exponent = tl.exp(student_val - max_val)
+        exp_sum += exponent

-        # now compute logsumexp for the K student logits
-        # logsumexp = max_val + log(sum( exp(student_val - max_val) ))
-        exp_sum = 0.0
-        for k in range(K):
-            if valid_k[k] != 0:  # if valid
-                exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
-                exp_sum += exp_val
-        # safe check
-        epsilon = 1e-8  # Small constant to prevent log(0)
-        exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum)
-        logsumexp_val = logsumexp_val + tl.log(exp_sum)
+    # final logsumexp
+    logsumexp_val = max_val + tl.log(exp_sum)

-        # compute partial kl
-        # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
-        # teacher_probs_k = exp(teacher_logprobs_k)
-        for k in range(K):
-            if valid_k[k] != 0:  # only valid tokens
-                teacher_prob = tl.exp(teacher_logprobs_k[k])
-                student_logprob = student_logits_k[k] - logsumexp_val
-                kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
-                kd_reg += kd_val
+    # -------------------------------------------------------------------------
+    # Third pass: compute partial KL per position
+    # KL = sum_{k in valid} p^T_k * (teacher_logprobs_k - student_logprobs_k)
+    #
+    #   - teacher_logprobs_k => t_log
+    #   - teacher_prob_k     =  exp(t_log)
+    #   - student_logprobs_k =  s_val - logsumexp_val
+    # -------------------------------------------------------------------------
+    for k in range(K):
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k
+        # teacher logprobs
+        t_log = tl.load(
+            teacher_logprobs_ptr + offset_k,
+            mask=mask_pos,
+            other=-1e30,
+        )
+        # teacher prob
+        t_prob = tl.exp(t_log)

-    # Write out partial kd for this block. We store a single partial sum in partial_kd_ptr
-    # We'll store it at partial_kd_ptr[pid]
-    # In real code, you might do an atomic add into partial_kd_ptr or a parallel reduction pass
-    # for now, let's just store it at index=pid
-    tl.store(partial_kd_ptr + pid, kd_reg)
+        # student logit
+        s_val = tl.load(
+            student_logits_ptr + offset_k,
+            mask=mask_pos,
+            other=-1e30,
+        )
+        # student logprob
+        s_logprob = s_val - logsumexp_val
+        # local KL
+        kl_val = t_prob * (t_log - s_logprob)
+
+        # also read the mask to disable invalid tokens if the mask is not purely sequence-based
+        valid_k = tl.load(mask_ptr + offset_k, mask=mask_pos, other=0)
+        # works whether the mask is bool or 0/1
+        is_valid = valid_k > 0
+
+        # zero out if this index is out-of-bounds or the mask is invalid
+        kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
+
+        # accumulate
+        kl_acc += kl_val
+
+    # -------------------------------------------------------------------------
+    # Store the partial KL in partial_kd_ptr for each element in idx.
+    # Later in Python, you can do partial_kd.sum() to get the total KL.
+    # -------------------------------------------------------------------------
+    tl.store(partial_kd_ptr + idx, kl_acc, mask=mask_pos)
+
+
+def kd_forward_pass_triton(
+    student_logits,  # [B, seq_len, K] (already gathered)
+    teacher_logprobs,  # [B, seq_len, K]
+    mask,  # [B, seq_len, K] bool or 0/1
+    BLOCK_SIZE=1024,
+):
+    """
+    Returns the total KL as a scalar tensor. We do the sum on the Python side.
+    NOTE: No normalization is done here.
+    You might divide by `num_items_in_batch` or the number of valid tokens afterward.
+ """ + B, seq_len, K = student_logits.shape + # Flatten + student_logits_flat = student_logits.reshape(-1) + teacher_logprobs_flat = teacher_logprobs.reshape(-1) + mask_flat = mask.reshape(-1) + + total_positions = B * seq_len + # We'll store partial KL sums for each of the B*seq_len positions + partial_kd = torch.empty( + total_positions, dtype=student_logits.dtype, device=student_logits.device + ) + + # Grid config + grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,) + + kd_forward_kernel[grid]( + student_logits_flat, + teacher_logprobs_flat, + mask_flat, + partial_kd, + B, seq_len, K, + BLOCK_SIZE=BLOCK_SIZE + ) + + # Sum on CPU or GPU + kd_sum = partial_kd.sum() + return kd_sum class _KLDivergenceTritonFn(torch.autograd.Function): @staticmethod - def forward(ctx, student_logits, teacher_logprobs, mask, num_items_in_batch): + def forward(ctx, student_logits, teacher_logprobs, mask): """ - Inputs shape assumptions (after gather!): - - student_logits: [B, seq_len, K] - - teacher_logprobs: [B, seq_len, K] - - mask: [B, seq_len, K] (bool or 0/1) for valid tokens + student_logits: (B, seq_len, K) + teacher_logprobs: (B, seq_len, K) + mask: (B, seq_len, K) """ - B, seq_len, K = student_logits.shape - - # Prepare output buffer for partial sums - # We'll have BLOCK_SIZE define how many (batch*seq_len) items each block processes - # For simplicity, let's aim for one block per 1024 positions - BLOCK_SIZE = 1024 - # compute how many blocks we need - total_positions = B * seq_len - grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,) - - partial_kd = torch.empty( - grid[0], dtype=student_logits.dtype, device=student_logits.device - ) - - # Launch kernel - kd_forward_kernel[grid]( - student_logits, - teacher_logprobs, - mask, - partial_kd, - B, - seq_len, - K, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Sum partials on CPU or GPU - kd_sum = partial_kd.sum() - - # normalize - if num_items_in_batch is not None: - kd_loss = kd_sum / num_items_in_batch - else: - # Just average over all valid tokens; in practice you'd need the count of valid tokens - # For a quick approximation, let's do kd_sum / total_positions (or do a separate reduction on mask) - # This is a simplification. For correctness, you should count valid tokens in the kernel. - kd_loss = kd_sum / (total_positions * K) - - # Save context for backward - # Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad - # But be mindful of memory usage. We’ll demonstrate the minimal approach here: - ctx.save_for_backward( - student_logits, - teacher_logprobs, - mask, - torch.tensor(num_items_in_batch or 0), - ) - ctx.B = B - ctx.seq_len = seq_len - ctx.K = K - ctx.total_positions = total_positions - ctx.BLOCK_SIZE = BLOCK_SIZE + kd_sum = kd_forward_pass_triton(student_logits, teacher_logprobs, mask) + kd_loss = kd_sum # Not normalized here. You can do that externally. + # Save for backward + ctx.save_for_backward(student_logits, teacher_logprobs, mask) return kd_loss + @staticmethod def backward(ctx, grad_output): - """ - grad_output is dLoss/dOut (a scalar). - We want dLoss/dStudentLogits. - Recall that: - - Loss = sum_{valid k} p^T_k ( log p^T_k - (student_logits_k - logsumexp(student_logits_all_k)) ) - = sum_{valid k} p^T_k log p^T_k - sum_{valid k} p^T_k student_logits_k + sum_{valid k} p^T_k logsumexp(...) - - Let’s break down the derivative wrt student_logits_k. 
-        More precisely, from:
-          d/d student_logits_k [ - p^T_k student_logprobs_k ]
-        you get:
-          - p^T_k * ( d/d student_logits_k [ student_logits_k - logsumexp(...) ] )
-          = - p^T_k * (1 - p^S_k)
-          = p^T_k * p^S_k - p^T_k
-          = p^S_k * p^T_k - p^T_k
-          = p^T_k( p^S_k - 1 )
-
-        In practice, we also must handle the mask.
-        A real implementation typically re-runs the gather & logsumexp calculations or caches them in forward().
-        For brevity, we do a naive approach in PyTorch (not Triton) for the backward.
-        For maximum speed, you'd do a second Triton kernel.
-
-        We'll do a minimal approach here: recompute everything on the host side or a pure PyTorch pass.
-        """
-        student_logits, teacher_logprobs, mask, num_items_in_batch_t = ctx.saved_tensors
-        num_items_in_batch = int(num_items_in_batch_t.item())
-        B, seq_len, K = ctx.B, ctx.seq_len, ctx.K
-
-        # We can either replicate the entire forward logic in PyTorch for gradient
-        # or do a second Triton pass. Here, let's do it in PyTorch for clarity.
-
-        # 1) compute logsumexp of student_logits_k for each [b, s]
-        # 2) compute p^S_k
-        # 3) compute p^T_k from teacher_logprobs
-        # 4) dLoss/dStudentLogits = grad_output * p^T_k ( p^S_k - 1 ), masked
-        # 5) sum or gather the final gradient
+        # We'll do a naive PyTorch re-computation for the gradient wrt student_logits
+        student_logits, teacher_logprobs, mask = ctx.saved_tensors
+        # grad_output is dLoss/dOut => a scalar
+        # Compute dLoss/dStudentLogits with the same formula as the forward pass
         with torch.enable_grad():
-            # treat student_logits as if it requires grad
             stl = student_logits.clone().detach().requires_grad_(True)
-            # compute logsumexp along K
-            logsumexp_val = torch.logsumexp(
-                stl, dim=-1, keepdim=True
-            )  # [B, seq_len, 1]
-            student_logprobs_topk = stl - logsumexp_val
-            teacher_probs = teacher_logprobs.exp()
-            # p^S_k
-            p_s = student_logprobs_topk.exp()
+            t_log = teacher_logprobs
+            # mask might be bool or 0/1
+            # compute logsumexp
+            lse = torch.logsumexp(stl, dim=-1, keepdim=True)
+            s_logprob = stl - lse
+            t_prob = t_log.exp()

-            # forward kl = sum p^T_k ( teacher_logprobs_k - student_logprobs_topk )
-            # derivative wrt stl = p^T_k( p^S_k - 1 )
-            grad_stl = teacher_probs * (p_s - 1.0)
-            # respect the mask
-            grad_stl = grad_stl * mask  # zero out invalid
+            # forward KL = sum_{k} p^T_k ( t_log_k - s_logprob_k )
+            kl_val = t_prob * (t_log - s_logprob)
+            # mask out
+            kl_val = kl_val * mask  # zero out invalid

-            # sum or average
-            if num_items_in_batch != 0:
-                grad_stl = grad_stl / num_items_in_batch
-            else:
-                grad_stl = grad_stl / (B * seq_len * K)  # fallback
+            kd_loss = kl_val.sum()
+            # now compute dLoss/d stl
+            grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]

-        # multiply by upstream grad_output
-        grad_stl = grad_stl * grad_output
-
-        return grad_stl, None, None, None
+        return grad_stl, None, None


 def kd_loss_triton(
@@ -274,5 +253,5 @@ def kd_loss_triton(
     called gather on student_logits -> shape [B, seq_len, K].
     """
     return _KLDivergenceTritonFn.apply(
-        student_logits, teacher_logprobs, mask, num_items_in_batch
+        student_logits, teacher_logprobs, mask,  # num_items_in_batch
     )
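Usage sketch (not part of the patch): a minimal example of how the now un-normalized return value could be consumed, assuming the top-K gather has already been applied and that the module path matches the file edited above; the shapes and tensor names below are illustrative only, not taken from the repo.

import torch

# Assumed import path, derived from src/axolotl/integrations/kd/kernels/kd.py.
from axolotl.integrations.kd.kernels.kd import kd_loss_triton

# Hypothetical shapes: batch of 2, sequence length 8, teacher top-K of 4.
B, seq_len, K = 2, 8, 4
student_logits_topk = torch.randn(B, seq_len, K, device="cuda", requires_grad=True)
teacher_logprobs = torch.log_softmax(torch.randn(B, seq_len, K, device="cuda"), dim=-1)
target_mask = torch.ones(B, seq_len, K, dtype=torch.bool, device="cuda")

# The autograd function now returns the raw sum of per-token KL terms.
kd_sum = kd_loss_triton(student_logits_topk, teacher_logprobs, target_mask)

# Normalize externally, e.g. by the number of valid tokens (one of the two
# options shown in the trainer diff); dividing by num_items_in_batch works the same way.
loss_kd = kd_sum / (target_mask.sum() + 1e-8)
loss_kd.backward()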