From 119d586cf4e1a6ea850fda48f21d4568fbbf9d93 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 21 Dec 2024 13:43:48 -0500
Subject: [PATCH] v2 trial

---
 src/axolotl/core/trainers/kd.py           |  49 +--
 src/axolotl/integrations/kd/kernels/kd.py | 464 +++++++++++-----------
 2 files changed, 252 insertions(+), 261 deletions(-)

diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index b604eb989..3b15c8484 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -7,10 +7,7 @@ from typing import Optional
 import torch
 
 from axolotl.core.trainers.base import AxolotlTrainer
-from axolotl.integrations.kd.kernels.kd import (
-    forward_kl_topk,
-    prepare_topk_student_teacher,
-)
+from axolotl.integrations.kd.kernels.kd import kd_loss_triton
 
 
 def kd_loss_function(
@@ -100,43 +97,29 @@ class AxolotlKDTrainer(AxolotlTrainer):
         outputs = model(**inputs)
         student_logits = outputs["logits"]
 
+        # Slice the student logits so they line up with the teacher sequence length, e.g.:
+        teacher_seq_len = target_token_ids.shape[1]
+        student_logits_for_kd = student_logits[:, :teacher_seq_len, :]  # [B, seq_len, vocab_size]
-        # Gather & flatten to [N, K]
-        stud_lp_f, teach_lp_f, mask_f = prepare_topk_student_teacher(
-            student_logits,
-            target_token_ids,
-            target_logprobs,
-            target_mask,
+        # Gather the student logits at the teacher's top-K token ids
+        student_logits_topk = torch.gather(
+            student_logits_for_kd, dim=-1, index=target_token_ids  # same shape [B, seq_len, K]
         )
-        loss_kd = forward_kl_topk(teach_lp_f, stud_lp_f, mask_f, reduction="none")
-
-        # Normalize by number of items or mean over valid tokens
-        if num_items_in_batch is not None:
-            # If you know how many items should be considered in the batch
-            loss_kd = loss_kd / num_items_in_batch
-        else:
-            # Otherwise, just average over all valid tokens
-            # count number of unmasked tokens in teacher_mask
-            kd_loss_per_token = target_mask.sum(dim=1).unsqueeze(-1)
-            # Normalize by number of unmasked tokens in teacher_mask
-            loss_kd = loss_kd / kd_loss_per_token.float()
+        # Now call the Triton-based KD loss
+        loss_kd = kd_loss_triton(
+            student_logits_topk,
+            target_logprobs,   # teacher logprobs [B, seq_len, K]
+            target_mask,       # mask [B, seq_len, K]
+            num_items_in_batch=num_items_in_batch,
+        )
+
+        # optionally combine with CE loss
         if self.args.kd_ce_alpha > 0:
             loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
         else:
             loss = loss_kd
 
-        # Save past state if it exists
-        # TODO: this needs to be fixed and made cleaner later.
-        if self.args.past_index >= 0:
-            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
-                self.args.past_index
-            ]
-
-        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
-            loss *= self.accelerator.num_processes
-
-        torch.cuda.empty_cache()
 
         return (loss, outputs) if return_outputs else loss
diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py
index 5d2ff19cf..058cfe334 100644
--- a/src/axolotl/integrations/kd/kernels/kd.py
+++ b/src/axolotl/integrations/kd/kernels/kd.py
@@ -2,267 +2,275 @@ import torch
 import triton
 import triton.language as tl
 
-configs = [
-    triton.Config({"BLOCK_SIZE": 32}, num_warps=1, num_stages=1),
-    triton.Config({"BLOCK_SIZE": 64}, num_warps=1, num_stages=1),
-    triton.Config({"BLOCK_SIZE": 128}, num_warps=2, num_stages=2),
-    # Add more if needed
-]
+
+# --------------------------------------------------------
+# Triton kernel for the forward pass
+# --------------------------------------------------------
+# Layout assumptions:
+#  - The inputs are contiguous [B, seq_len, K] tensors, indexed as a flattened
+#    [B * seq_len, K] view.
+#  - The launch grid is 1D over the B * seq_len positions; each program handles
+#    BLOCK_SIZE positions and all K top-K entries for each of them.
+#  - For very large K, a 2D grid may be needed to keep good occupancy.
+#
+# Steps inside the kernel:
+#  1) compute the flattened (batch, seq_position) indices for this program
+#  2) load the (already gathered) top-K student logits and teacher logprobs
+#  3) compute logsumexp over the K student logits
+#  4) compute student_logprobs_topk = logits - logsumexp
+#  5) compute teacher_probs = exp(teacher_logprobs)
+#  6) compute partial KL = sum(mask * teacher_probs * (teacher_logprobs - student_logprobs_topk))
+#  7) store one partial KL sum per program in a buffer
+#
+# Later, we reduce the per-program partial sums on the host.
+#
+# NOTE: This is a reference implementation; tune BLOCK_SIZE / num_warps for your shapes.
+
 
-@triton.autotune(configs=configs, key=["N", "K"])
 @triton.jit
-def fwd_kl_topk_kernel(
-    teacher_lp_ptr,  # float32 [N, K]
-    student_lp_ptr,  # float32 [N, K]
-    mask_ptr,  # bool [N, K]
-    loss_out_ptr,  # float32 [N]
-    stride_tn,
-    stride_tk,
-    stride_sn,
-    stride_sk,
-    stride_mn,
-    stride_mk,
-    stride_loss_n,
-    N: tl.constexpr,
-    K: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+def kd_forward_kernel(
+    student_logits_ptr,    # float32[B, seq_len, K] after gather
+    teacher_logprobs_ptr,  # float32[B, seq_len, K]
+    mask_ptr,              # int8[B, seq_len, K]; 1 = valid token
+    partial_kd_ptr,        # float32[num_blocks]; one partial sum per program
+    B,                     # total batch size
+    seq_len,               # total sequence length from teacher
+    K,                     # top-K from teacher
+    BLOCK_SIZE: tl.constexpr,  # how many positions each program handles
+    BLOCK_K: tl.constexpr,     # next power of two >= K (for tl.arange)
 ):
-    """
-    Each kernel instance: row_id = tl.program_id(0). We'll tile the K dimension in chunks of BLOCK_SIZE.
-    Summation => store into loss_out[row_id].
-    """
-    row_id = tl.program_id(0)
-    if row_id >= N:
-        return
-
-    # Base pointers for teacher, student, mask rows
-    t_row_ptr = teacher_lp_ptr + row_id * stride_tn
-    s_row_ptr = student_lp_ptr + row_id * stride_sn
-    m_row_ptr = mask_ptr + row_id * stride_mn
-
-    # We'll accumulate KL in local variable
-    kl_sum = 0.0
-
-    # tile the K dimension
-    num_tiles = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
-    for tile_id in range(num_tiles):
-        k_offset = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = k_offset < K
-
-        # load teacher logprobs
-        t_lp = tl.load(t_row_ptr + k_offset * stride_tk, mask=mask, other=-float("inf"))
-        # load student logprobs
-        s_lp = tl.load(s_row_ptr + k_offset * stride_sk, mask=mask, other=-float("inf"))
-
-        # load mask => bool (0 or 1)
-        valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0)
-        valid_f32 = valid.to(tl.float32)
-
-        # teacher probs
-        t_p = tl.exp(t_lp)
-
-        # local_kl = p^T * (lp^T - lp^S)
-        local_kl = t_p * (t_lp - s_lp)
-        # multiply by valid_f32 to ignore padded or invalid positions
-        local_kl *= valid_f32
-
-        # sum over the chunk
-        kl_sum += tl.sum(local_kl, where=mask)
-
-    # store rowwise result
-    tl.store(loss_out_ptr + row_id * stride_loss_n, kl_sum)
+    # program_id indexes the block of flattened (batch, seq) positions we handle
+    pid = tl.program_id(0)
+
+    # Each program owns BLOCK_SIZE flattened positions in [0 .. B * seq_len)
+    pos = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)           # [BLOCK_SIZE]
+    pos_valid = pos < B * seq_len
+
+    # Lane offsets over the K dimension (BLOCK_K is the next power of two >= K)
+    offs_k = tl.arange(0, BLOCK_K)                              # [BLOCK_K]
+    k_valid = offs_k < K
+
+    # 2D offsets into the contiguous [B * seq_len, K] view of the inputs
+    offs = pos[:, None] * K + offs_k[None, :]
+    load_mask = pos_valid[:, None] & k_valid[None, :]
+
+    s_logits = tl.load(
+        student_logits_ptr + offs, mask=load_mask, other=-float("inf")
+    ).to(tl.float32)
+    t_logprobs = tl.load(
+        teacher_logprobs_ptr + offs, mask=load_mask, other=-float("inf")
+    ).to(tl.float32)
+    valid = tl.load(mask_ptr + offs, mask=load_mask, other=0).to(tl.float32)
+
+    # Numerically stable logsumexp over the K student logits for each position
+    row_max = tl.max(s_logits, axis=1)
+    shifted = s_logits - row_max[:, None]
+    sumexp = tl.sum(tl.where(load_mask, tl.exp(shifted), 0.0), axis=1)
+    s_logprobs = shifted - tl.log(sumexp)[:, None]
+
+    # Partial KL: sum over valid entries of p^T * (log p^T - log p^S)
+    t_probs = tl.exp(t_logprobs)
+    kl = tl.where(valid > 0, t_probs * (t_logprobs - s_logprobs), 0.0)
+
+    # One partial sum per program; the host sums these afterwards
+    tl.store(partial_kd_ptr + pid, tl.sum(tl.sum(kl, axis=1), axis=0))
 
 
-@triton.autotune(configs=configs, key=["N", "K"])
-@triton.jit
-def bwd_kl_topk_kernel(
-    teacher_lp_ptr,  # float32 [N, K]
-    mask_ptr,  # bool [N, K]
-    grad_stud_ptr,  # float32 [N, K], output
-    stride_tn,
-    stride_tk,
-    stride_mn,
-    stride_mk,
-    stride_gn,
-    stride_gk,
-    N: tl.constexpr,
-    K: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    For forward KL, d/d(student_lp) = - exp(teacher_lp), if mask=1, else 0.
-    Each kernel instance processes one row [K].
-    """
-    row_id = tl.program_id(0)
-    if row_id >= N:
-        return
-
-    t_row_ptr = teacher_lp_ptr + row_id * stride_tn
-    m_row_ptr = mask_ptr + row_id * stride_mn
-    g_row_ptr = grad_stud_ptr + row_id * stride_gn
-
-    num_tiles = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
-    for tile_id in range(num_tiles):
-        k_offset = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = k_offset < K
-
-        t_lp = tl.load(t_row_ptr + k_offset * stride_tk, mask=mask, other=-float("inf"))
-        valid = tl.load(m_row_ptr + k_offset * stride_mk, mask=mask, other=0).to(
-            tl.int1
-        )
-
-        grad_val = -tl.exp(t_lp)  # derivative
-        grad_val = tl.where(valid, grad_val, 0.0)
-
-        tl.store(g_row_ptr + k_offset * stride_gk, grad_val, mask=mask)
-
-
-class FwdKLTopKFunction(torch.autograd.Function):
+class _KLDivergenceTritonFn(torch.autograd.Function):
     @staticmethod
-    def forward(
-        ctx,
-        teacher_lp_topk: torch.Tensor,
-        student_lp_topk: torch.Tensor,
-        mask_topk: torch.Tensor,
-        reduction: str = "batchmean",
-    ) -> torch.Tensor:
+    def forward(ctx, student_logits, teacher_logprobs, mask, num_items_in_batch):
         """
-        teacher_lp_topk: [N, K]
-        student_lp_topk: [N, K]
-        mask_topk: [N, K] bool
-        returns either scalar (if batchmean) or [N] if 'none'
+        Input shape assumptions (after the gather!):
+          - student_logits:   [B, seq_len, K]
+          - teacher_logprobs: [B, seq_len, K]
+          - mask:             [B, seq_len, K] (bool or 0/1) marking valid tokens
         """
-        assert teacher_lp_topk.shape == student_lp_topk.shape
-        assert teacher_lp_topk.shape == mask_topk.shape
+        B, seq_len, K = student_logits.shape
 
-        N, K = teacher_lp_topk.shape
-        dev = teacher_lp_topk.device
-        dtype = teacher_lp_topk.dtype
+        # The kernel indexes a contiguous [B * seq_len, K] view of the inputs
+        student_logits_c = student_logits.contiguous()
+        teacher_logprobs_c = teacher_logprobs.contiguous()
+        mask_c = mask.to(torch.int8).contiguous()
 
+        # BLOCK_SIZE = how many (batch * seq_len) positions each program handles.
+        # Keep it modest, since each program also holds K entries per position.
+        BLOCK_SIZE = 32
+        total_positions = B * seq_len
+        grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
-        # Contiguous
-        t_lp_c = teacher_lp_topk.contiguous()
-        s_lp_c = student_lp_topk.contiguous()
-        m_c = mask_topk.contiguous()
-
-        # [N] rowwise sums
-        loss_out = torch.empty(N, dtype=torch.float32, device=dev)
-
-        grid = (N,)
-
-        fwd_kl_topk_kernel[grid](
-            t_lp_c,
-            s_lp_c,
-            m_c,
-            loss_out,
-            # strides
-            t_lp_c.stride(0),
-            t_lp_c.stride(1),
-            s_lp_c.stride(0),
-            s_lp_c.stride(1),
-            m_c.stride(0),
-            m_c.stride(1),
-            loss_out.stride(0),
-            N=N,
-            K=K
-            # BLOCK_SIZE, warps, stages => autotune
-        )
+        # float32 accumulator: one partial KL sum per program
+        partial_kd = torch.empty(
+            grid[0], dtype=torch.float32, device=student_logits.device
+        )
 
-        if reduction == "batchmean":
-            loss_val = loss_out.mean()
-        elif reduction == "none":
-            loss_val = loss_out
+        # Launch kernel
+        BLOCK_K = triton.next_power_of_2(K)  # tl.arange needs a power-of-two extent
+        kd_forward_kernel[grid](
+            student_logits_c,
+            teacher_logprobs_c,
+            mask_c,
+            partial_kd,
+            B, seq_len, K,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BLOCK_K=BLOCK_K,
+        )
+
+        # Sum the per-program partial sums
+        kd_sum = partial_kd.sum()
+
+        # normalize
+        if num_items_in_batch is not None:
+            kd_loss = kd_sum / num_items_in_batch
         else:
-            raise ValueError("reduction must be 'batchmean' or 'none'")
+            # Average over all positions as a simplification; for strict correctness
+            # you would count the valid tokens (e.g. with a reduction over the mask).
+            kd_loss = kd_sum / (total_positions * K)
 
-        # Save for backward
-        ctx.save_for_backward(t_lp_c, m_c)
-        ctx.reduction = reduction
-        ctx.shape = (N, K)
+        # Save context for backward.
+        # Save only what backward actually needs; be mindful of memory usage.
+        ctx.save_for_backward(student_logits_c, teacher_logprobs_c, mask_c)
+        ctx.num_items_in_batch = num_items_in_batch
+        ctx.B = B
+        ctx.seq_len = seq_len
+        ctx.K = K
 
-        return loss_val
+        return kd_loss
 
     @staticmethod
     def backward(ctx, grad_output):
-        # grad_output is either scalar ([1]) if batchmean, or shape=[N] if 'none'
-        t_lp_c, m_c = ctx.saved_tensors
-        (N, K) = ctx.shape
-
-        # We'll create a grad for the student's top-K logprobs
-        grad_stud = torch.empty_like(t_lp_c)  # [N, K]
-
-        grid = (N,)
-        bwd_kl_topk_kernel[grid](
-            t_lp_c,
-            m_c,
-            grad_stud,
-            t_lp_c.stride(0),
-            t_lp_c.stride(1),
-            m_c.stride(0),
-            m_c.stride(1),
-            grad_stud.stride(0),
-            grad_stud.stride(1),
-            N=N,
-            K=K,
-        )
-
-        # Multiply by grad_output
-        # If batchmean => scalar
-        # If none => shape=[N]
-        if grad_output.numel() == 1:
-            grad_stud *= grad_output
-        else:
-            # shape=[N], broadcast over K
-            grad_stud *= grad_output.unsqueeze(1)
-
-        return grad_stud, None, None, None
+        """
+        grad_output is dLoss/dOut (a scalar).
+        We want dLoss/dStudentLogits.
+
+        Per position, with m_k the validity mask and log p^S_k = s_k - logsumexp(s):
+
+            Loss = sum_k m_k * p^T_k * ( log p^T_k - (s_k - logsumexp(s)) )
+
+        Differentiating wrt s_j (note that logsumexp(s) depends on *all* s_k):
+
+            dLoss/ds_j = - m_j * p^T_j + ( sum_k m_k * p^T_k ) * p^S_j
+
+        If all K entries are valid and the teacher's top-K mass sums to ~1, this
+        reduces to the familiar softmax-CE form p^S_j - p^T_j.
+
+        A real implementation would typically cache the student logprobs from
+        forward() or recompute them in a second Triton kernel.
+        For brevity, we do the backward in plain PyTorch here.
+        """
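+        student_logits, teacher_logprobs, mask = ctx.saved_tensors
+        num_items_in_batch = ctx.num_items_in_batch
+        B, seq_len, K = ctx.B, ctx.seq_len, ctx.K
+
+        # Recompute the student log-softmax over the top-K slice in PyTorch.
+        # (A second Triton kernel would be the faster option.)
+        logsumexp_val = torch.logsumexp(student_logits.float(), dim=-1, keepdim=True)  # [B, seq_len, 1]
+        student_logprobs_topk = student_logits.float() - logsumexp_val
+        p_s = student_logprobs_topk.exp()                # p^S_k
+        p_t = teacher_logprobs.float().exp() * mask      # m_k * p^T_k (masked teacher probs)
+
+        # dLoss/ds_j = - m_j * p^T_j + (sum_k m_k * p^T_k) * p^S_j
+        teacher_mass = p_t.sum(dim=-1, keepdim=True)     # [B, seq_len, 1]
+        grad_stl = teacher_mass * p_s - p_t
+
+        # apply the same normalization as the forward pass
+        if num_items_in_batch is not None:
+            grad_stl = grad_stl / num_items_in_batch
+        else:
+            grad_stl = grad_stl / (B * seq_len * K)  # fallback, mirrors forward()
+
+        # multiply by upstream grad_output and match the input dtype
+        grad_stl = (grad_stl * grad_output).to(student_logits.dtype)
+
+        return grad_stl, None, None, None
 
 
-def forward_kl_topk(
-    teacher_lp_topk: torch.Tensor,
-    student_lp_topk: torch.Tensor,
-    mask_topk: torch.Tensor,
-    reduction: str = "batchmean",
-) -> torch.Tensor:
+def kd_loss_triton(
+    student_logits,     # [B, seq_len, K], already gathered at the teacher's top-K token ids
+    teacher_logprobs,   # [B, seq_len, K]
+    mask,               # [B, seq_len, K]
+    num_items_in_batch=None,
+):
     """
-    Calls the autograd function that launches Triton kernels for forward + backward.
+    Wrapper that calls our Triton-based forward+backward for KD.
+    For production, you likely want to do the gather (teacher top-K) outside
+    or inside a separate kernel; this function expects that you've *already*
+    gathered student_logits to shape [B, seq_len, K].
     """
-    return FwdKLTopKFunction.apply(
-        teacher_lp_topk, student_lp_topk, mask_topk, reduction
-    )
-
-
-def prepare_topk_student_teacher(
-    student_logits: torch.Tensor,  # [B, teacher_seq_len, vocab_size]
-    target_token_ids: torch.Tensor,  # [B, teacher_seq_len, K]
-    target_logprobs: torch.Tensor,  # [B, teacher_seq_len, K], teacher logprobs
-    target_mask: torch.Tensor,  # [B, teacher_seq_len, K], bool or 0/1
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Gathers student logits for the teacher's top-K tokens and flattens the first 2 dims => N = B * teacher_seq_len.
-
-    Returns:
-        (student_lp_topk, teacher_lp_topk, valid_mask) each shape = [N, K].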
-    """
-    B, S, K = target_token_ids.shape
-    # Gather the student logits => [B, S, K]
-    # 1) slice or use the entire student_logits if it matches teacher_seq_len
-    student_logits_for_kd = student_logits[:, :S, :]  # ensure alignment if needed
-
-    # 2) gather top-k => [B, S, K]
-    student_logits_topk = torch.gather(
-        student_logits_for_kd, dim=-1, index=target_token_ids
-    )
-
-    # 3) convert student logits to logprobs => [B, S, K]
-    student_logprobs_topk = student_logits_topk - torch.logsumexp(
-        student_logits_topk, dim=-1, keepdim=True
-    )
-
-    # Flatten batch dimension
-    N = B * S
-    student_logprobs_topk_f = student_logprobs_topk.view(N, K)  # [N, K]
-    teacher_logprobs_topk_f = target_logprobs.view(N, K)  # [N, K]
-    mask_f = target_mask.view(N, K).bool()  # [N, K]
-
-    return student_logprobs_topk_f, teacher_logprobs_topk_f, mask_f
+    return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch)
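
A quick way to validate the new code path is to compare kd_loss_triton against a dense PyTorch reference on random inputs, checking both the forward value and the gradient produced by the custom backward. The sketch below is not part of the patch: kd_loss_reference is a hypothetical helper written only for this check, the shapes and num_items value are illustrative, and it assumes a CUDA device plus the patched axolotl.integrations.kd.kernels.kd module on the import path.

import torch

from axolotl.integrations.kd.kernels.kd import kd_loss_triton


def kd_loss_reference(student_logits_topk, teacher_logprobs, mask, num_items_in_batch=None):
    """Pure-PyTorch forward KL over the teacher's top-K, for comparison."""
    student_logprobs = student_logits_topk - torch.logsumexp(
        student_logits_topk, dim=-1, keepdim=True
    )
    kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
    kl = kl * mask  # zero out invalid entries
    denom = num_items_in_batch if num_items_in_batch is not None else mask.numel()
    return kl.sum() / denom


def main():
    torch.manual_seed(0)
    B, seq_len, K = 2, 8, 16  # illustrative shapes
    device = "cuda"

    student_logits_topk = torch.randn(B, seq_len, K, device=device, requires_grad=True)
    teacher_logprobs = torch.log_softmax(torch.randn(B, seq_len, K, device=device), dim=-1)
    mask = torch.ones(B, seq_len, K, device=device)
    num_items = B * seq_len

    # Reference forward value and gradient via autograd
    ref = kd_loss_reference(student_logits_topk, teacher_logprobs, mask, num_items)
    (ref_grad,) = torch.autograd.grad(ref, student_logits_topk, retain_graph=True)

    # Triton path (custom forward kernel + hand-written backward)
    tri = kd_loss_triton(student_logits_topk, teacher_logprobs, mask, num_items_in_batch=num_items)
    (tri_grad,) = torch.autograd.grad(tri, student_logits_topk)

    print("forward abs diff:", (ref - tri).abs().item())
    print("grad max abs diff:", (ref_grad - tri_grad).abs().max().item())


if __name__ == "__main__":
    main()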