no log etc

This commit is contained in:
Wing Lian
2024-12-21 13:54:21 -05:00
parent 0da2b7c7cc
commit 9c0470130b

View File

@@ -2,7 +2,6 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
# -------------------------------------------------------- # --------------------------------------------------------
# Triton Kernel for forward pass # Triton Kernel for forward pass
# -------------------------------------------------------- # --------------------------------------------------------
@@ -30,21 +29,21 @@ import triton.language as tl
@triton.jit @triton.jit
def kd_forward_kernel( def kd_forward_kernel(
student_logits_ptr, # float32[B, seq_len, K] after gather student_logits_ptr, # float32[B, seq_len, K] after gather
teacher_logprobs_ptr, # float32[B, seq_len, K] teacher_logprobs_ptr, # float32[B, seq_len, K]
mask_ptr, # bool[B, seq_len, K] or int8 mask_ptr, # bool[B, seq_len, K] or int8
partial_kd_ptr, # float32[B, seq_len] (accumulator) partial_kd_ptr, # float32[B, seq_len] (accumulator)
B, # total batch size B, # total batch size
seq_len, # total sequence length from teacher seq_len, # total sequence length from teacher
K, # top-K from teacher K, # top-K from teacher
BLOCK_SIZE: tl.constexpr # how many tokens per block in dimension0 BLOCK_SIZE: tl.constexpr, # how many tokens per block in dimension0
): ):
# program_id is the global index for each block # program_id is the global index for each block
pid = tl.program_id(0) pid = tl.program_id(0)
# Each block handles a range of seq positions in [0..B*seq_len) # Each block handles a range of seq positions in [0..B*seq_len)
block_start = pid * BLOCK_SIZE block_start = pid * BLOCK_SIZE
block_end = tl.min((pid+1)*BLOCK_SIZE, B * seq_len) block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
length = block_end - block_start length = block_end - block_start
# Offsets for indexing # Offsets for indexing
@@ -69,7 +68,7 @@ def kd_forward_kernel(
# For K top tokens, read the relevant student logits and teacher logprobs # For K top tokens, read the relevant student logits and teacher logprobs
# We'll load them in a small loop: # We'll load them in a small loop:
logsumexp_val = float('-inf') logsumexp_val = float("-inf")
# We'll store them in a local array for a second pass # We'll store them in a local array for a second pass
student_logits_k = [0.0 for _ in range(K)] student_logits_k = [0.0 for _ in range(K)]
teacher_logprobs_k = [0.0 for _ in range(K)] teacher_logprobs_k = [0.0 for _ in range(K)]
@@ -79,26 +78,17 @@ def kd_forward_kernel(
for k in range(K): for k in range(K):
# load student logit # load student logit
student_val = tl.load( student_val = tl.load(
student_logits_ptr student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
+ b_idx*seq_len*K mask=(b_idx < B) and (s_idx < seq_len),
+ s_idx*K
+ k,
mask=(b_idx < B) and (s_idx < seq_len)
) )
teacher_val = tl.load( teacher_val = tl.load(
teacher_logprobs_ptr teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
+ b_idx*seq_len*K mask=(b_idx < B) and (s_idx < seq_len),
+ s_idx*K
+ k,
mask=(b_idx < B) and (s_idx < seq_len)
) )
# get mask # get mask
mask_val = tl.load( mask_val = tl.load(
mask_ptr mask_ptr + b_idx * seq_len * K + s_idx * K + k,
+ b_idx*seq_len*K mask=(b_idx < B) and (s_idx < seq_len),
+ s_idx*K
+ k,
mask=(b_idx < B) and (s_idx < seq_len)
) )
student_logits_k[k] = student_val student_logits_k[k] = student_val
@@ -119,7 +109,7 @@ def kd_forward_kernel(
# safe check # safe check
if exp_sum == 0.0: if exp_sum == 0.0:
exp_sum = 1e-8 exp_sum = 1e-8
logsumexp_val = logsumexp_val + float(torch.log(torch.tensor(exp_sum))) logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))
# compute partial kl # compute partial kl
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k) # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
@@ -155,7 +145,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
# compute how many blocks we need # compute how many blocks we need
total_positions = B * seq_len total_positions = B * seq_len
grid = ( (total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE , ) grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
partial_kd = torch.empty( partial_kd = torch.empty(
grid[0], dtype=student_logits.dtype, device=student_logits.device grid[0], dtype=student_logits.dtype, device=student_logits.device
@@ -167,8 +157,10 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
teacher_logprobs, teacher_logprobs,
mask, mask,
partial_kd, partial_kd,
B, seq_len, K, B,
BLOCK_SIZE=BLOCK_SIZE seq_len,
K,
BLOCK_SIZE=BLOCK_SIZE,
) )
# Sum partials on CPU or GPU # Sum partials on CPU or GPU
@@ -186,7 +178,12 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
# Save context for backward # Save context for backward
# Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad # Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
# But be mindful of memory usage. Well demonstrate the minimal approach here: # But be mindful of memory usage. Well demonstrate the minimal approach here:
ctx.save_for_backward(student_logits, teacher_logprobs, mask, torch.tensor(num_items_in_batch or 0)) ctx.save_for_backward(
student_logits,
teacher_logprobs,
mask,
torch.tensor(num_items_in_batch or 0),
)
ctx.B = B ctx.B = B
ctx.seq_len = seq_len ctx.seq_len = seq_len
ctx.K = K ctx.K = K
@@ -238,7 +235,9 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
# treat student_logits as if it requires grad # treat student_logits as if it requires grad
stl = student_logits.clone().detach().requires_grad_(True) stl = student_logits.clone().detach().requires_grad_(True)
# compute logsumexp along K # compute logsumexp along K
logsumexp_val = torch.logsumexp(stl, dim=-1, keepdim=True) # [B, seq_len, 1] logsumexp_val = torch.logsumexp(
stl, dim=-1, keepdim=True
) # [B, seq_len, 1]
student_logprobs_topk = stl - logsumexp_val student_logprobs_topk = stl - logsumexp_val
teacher_probs = teacher_logprobs.exp() teacher_probs = teacher_logprobs.exp()
# p^S_k # p^S_k
@@ -266,7 +265,7 @@ def kd_loss_triton(
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
teacher_logprobs, teacher_logprobs,
mask, mask,
num_items_in_batch=None num_items_in_batch=None,
): ):
""" """
Wrapper that calls our Triton-based forward+backward for KD. Wrapper that calls our Triton-based forward+backward for KD.
@@ -274,4 +273,6 @@ def kd_loss_triton(
or inside a separate kernel. This function expects that you've *already* or inside a separate kernel. This function expects that you've *already*
called gather on student_logits -> shape [B, seq_len, K]. called gather on student_logits -> shape [B, seq_len, K].
""" """
return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch) return _KLDivergenceTritonFn.apply(
student_logits, teacher_logprobs, mask, num_items_in_batch
)