no log etc

Wing Lian
2024-12-21 13:54:21 -05:00
parent 5a7d6f6175
commit 9b1164b841

@@ -2,7 +2,6 @@ import torch
 import triton
 import triton.language as tl
-

 # --------------------------------------------------------
 # Triton Kernel for forward pass
 # --------------------------------------------------------
@@ -30,21 +29,21 @@ import triton.language as tl
 @triton.jit
 def kd_forward_kernel(
-    student_logits_ptr,      # float32[B, seq_len, K] after gather
-    teacher_logprobs_ptr,    # float32[B, seq_len, K]
-    mask_ptr,                # bool[B, seq_len, K] or int8
-    partial_kd_ptr,          # float32[B, seq_len] (accumulator)
-    B,                       # total batch size
-    seq_len,                 # total sequence length from teacher
-    K,                       # top-K from teacher
-    BLOCK_SIZE: tl.constexpr # how many tokens per block in dimension0
+    student_logits_ptr,  # float32[B, seq_len, K] after gather
+    teacher_logprobs_ptr,  # float32[B, seq_len, K]
+    mask_ptr,  # bool[B, seq_len, K] or int8
+    partial_kd_ptr,  # float32[B, seq_len] (accumulator)
+    B,  # total batch size
+    seq_len,  # total sequence length from teacher
+    K,  # top-K from teacher
+    BLOCK_SIZE: tl.constexpr,  # how many tokens per block in dimension0
 ):

     # program_id is the global index for each block
     pid = tl.program_id(0)

     # Each block handles a range of seq positions in [0..B*seq_len)
     block_start = pid * BLOCK_SIZE
-    block_end = tl.min((pid+1)*BLOCK_SIZE, B * seq_len)
+    block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
     length = block_end - block_start

     # Offsets for indexing
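
Each program in this 1D grid owns a contiguous slice of the flattened [0, B*seq_len) positions, and a flat offset decomposes into coordinates as b_idx = offs // seq_len and s_idx = offs % seq_len. A minimal self-contained sketch of that pattern (the kernel name and output buffer below are illustrative, not from this commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _flat_index_demo(out_ptr, B, seq_len, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # flat positions for this program
        valid = offs < B * seq_len                          # guard the ragged last block
        b_idx = offs // seq_len                             # batch coordinate
        s_idx = offs % seq_len                              # sequence coordinate
        tl.store(out_ptr + offs, b_idx * seq_len + s_idx, mask=valid)

    out = torch.empty(6, dtype=torch.int32, device="cuda")
    _flat_index_demo[(1,)](out, 2, 3, BLOCK_SIZE=8)  # B=2, seq_len=3 -> 6 positions
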
@@ -69,7 +68,7 @@ def kd_forward_kernel(
     # For K top tokens, read the relevant student logits and teacher logprobs
     # We'll load them in a small loop:
-    logsumexp_val = float('-inf')
+    logsumexp_val = float("-inf")

     # We'll store them in a local array for a second pass
     student_logits_k = [0.0 for _ in range(K)]
     teacher_logprobs_k = [0.0 for _ in range(K)]
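
Python-level lists such as student_logits_k do not lower through @triton.jit; the idiomatic Triton form loads all K entries as a vector and reduces over it, which also yields a numerically stable logsumexp in one pass. A sketch of that alternative for a single (b_idx, s_idx) position, assuming a constexpr power-of-two BLOCK_K >= K (a parameter this commit does not define):

    # inside the kernel body, per position; BLOCK_K: tl.constexpr, power of two >= K
    k_offs = tl.arange(0, BLOCK_K)
    k_valid = k_offs < K
    base = b_idx * seq_len * K + s_idx * K
    s_vals = tl.load(student_logits_ptr + base + k_offs, mask=k_valid, other=float("-inf"))
    # stable logsumexp over the gathered top-K logits (masked lanes contribute exp(-inf) = 0)
    m = tl.max(s_vals, axis=0)
    logsumexp_val = m + tl.log(tl.sum(tl.exp(s_vals - m), axis=0))
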
@@ -79,26 +78,17 @@ def kd_forward_kernel(
     for k in range(K):
         # load student logit
         student_val = tl.load(
-            student_logits_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )
         teacher_val = tl.load(
-            teacher_logprobs_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )
         # get mask
         mask_val = tl.load(
-            mask_ptr
-            + b_idx*seq_len*K
-            + s_idx*K
-            + k,
-            mask=(b_idx < B) and (s_idx < seq_len)
+            mask_ptr + b_idx * seq_len * K + s_idx * K + k,
+            mask=(b_idx < B) and (s_idx < seq_len),
         )

         student_logits_k[k] = student_val
@@ -119,7 +109,7 @@ def kd_forward_kernel(
     # safe check
     if exp_sum == 0.0:
         exp_sum = 1e-8
-    logsumexp_val = logsumexp_val + float(torch.log(torch.tensor(exp_sum)))
+    logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))

     # compute partial kl
     # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
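
Two things are worth pinning down here. First, torch.tensor cannot execute inside a @triton.jit kernel, so the device-side form of the changed line would be plain tl.log(exp_sum). Second, the per-position quantity being accumulated is the KL divergence restricted to the teacher's top-K slots, sum_k p^T_k (log p^T_k - log p^S_k). A plain PyTorch reference of the whole forward reduction, useful as a correctness check against the kernel (a sketch, not code from this commit):

    import torch

    def kd_forward_reference(student_logits, teacher_logprobs, mask):
        # all inputs: [B, seq_len, K], already gathered to the teacher's top-K
        student_logprobs = student_logits - torch.logsumexp(student_logits, dim=-1, keepdim=True)
        teacher_probs = teacher_logprobs.exp()
        kd = teacher_probs * (teacher_logprobs - student_logprobs)  # p^T * (log p^T - log p^S)
        kd = torch.where(mask.bool(), kd, torch.zeros_like(kd))     # drop masked top-K slots
        return kd.sum()
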
@@ -155,7 +145,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         BLOCK_SIZE = 1024
         # compute how many blocks we need
         total_positions = B * seq_len
-        grid = ( (total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE , )
+        grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)

         partial_kd = torch.empty(
             grid[0], dtype=student_logits.dtype, device=student_logits.device
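
The grid expression is a ceiling division; Triton ships a helper for exactly this, so an equivalent spelling would be:

    grid = (triton.cdiv(total_positions, BLOCK_SIZE),)
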
@@ -167,8 +157,10 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
             teacher_logprobs,
             mask,
             partial_kd,
-            B, seq_len, K,
-            BLOCK_SIZE=BLOCK_SIZE
+            B,
+            seq_len,
+            K,
+            BLOCK_SIZE=BLOCK_SIZE,
         )

         # Sum partials on CPU or GPU
@@ -186,7 +178,12 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         # Save context for backward
         # Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
         # But be mindful of memory usage. We'll demonstrate the minimal approach here:
-        ctx.save_for_backward(student_logits, teacher_logprobs, mask, torch.tensor(num_items_in_batch or 0))
+        ctx.save_for_backward(
+            student_logits,
+            teacher_logprobs,
+            mask,
+            torch.tensor(num_items_in_batch or 0),
+        )
         ctx.B = B
         ctx.seq_len = seq_len
         ctx.K = K
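
ctx.save_for_backward only accepts tensors, which is why the optional scalar num_items_in_batch is wrapped in torch.tensor with 0 standing in for None, while the plain ints B, seq_len, and K ride along as ctx attributes. Presumably backward unpacks this state along these lines (a hypothetical sketch; the commit's actual unpacking code is not shown in this diff):

    student_logits, teacher_logprobs, mask, num_items = ctx.saved_tensors
    B, seq_len, K = ctx.B, ctx.seq_len, ctx.K
    num_items_in_batch = int(num_items.item()) or None  # map the 0 sentinel back to None
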
@@ -238,7 +235,9 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         # treat student_logits as if it requires grad
         stl = student_logits.clone().detach().requires_grad_(True)
         # compute logsumexp along K
-        logsumexp_val = torch.logsumexp(stl, dim=-1, keepdim=True) # [B, seq_len, 1]
+        logsumexp_val = torch.logsumexp(
+            stl, dim=-1, keepdim=True
+        )  # [B, seq_len, 1]
         student_logprobs_topk = stl - logsumexp_val
         teacher_probs = teacher_logprobs.exp()
         # p^S_k
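
This recomputes the top-K log-softmax under autograd so the chain rule can be replayed through it. For this particular loss the gradient also has a closed form: with p^S = softmax(stl) over the K slots and masked teacher probabilities p^T, the derivative of sum_k p^T_k (log p^T_k - log p^S_k) with respect to stl_j is p^S_j * (sum_k p^T_k) - p^T_j. A sketch of an autograd-free backward built on that identity (my derivation, not code from this commit):

    import torch

    def kd_backward_closed_form(stl, teacher_logprobs, mask, grad_scale=1.0):
        # stl, teacher_logprobs, mask: [B, seq_len, K]
        p_s = torch.softmax(stl, dim=-1)
        p_t = teacher_logprobs.exp() * mask.to(stl.dtype)  # masked teacher probs
        grad = p_s * p_t.sum(dim=-1, keepdim=True) - p_t   # p^S * (sum p^T) - p^T
        return grad * grad_scale                           # scale by upstream grad / normalizer
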
@@ -266,7 +265,7 @@ def kd_loss_triton(
     student_logits,  # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
     teacher_logprobs,
     mask,
-    num_items_in_batch=None
+    num_items_in_batch=None,
 ):
     """
     Wrapper that calls our Triton-based forward+backward for KD.
@@ -274,4 +273,6 @@ def kd_loss_triton(
     or inside a separate kernel. This function expects that you've *already*
     called gather on student_logits -> shape [B, seq_len, K].
     """
-    return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch)
+    return _KLDivergenceTritonFn.apply(
+        student_logits, teacher_logprobs, mask, num_items_in_batch
+    )
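
A usage sketch of the intended call pattern, gathering the student's logits at the teacher's top-K token ids before handing them to the wrapper (all tensor names and shapes below are illustrative, not from the commit):

    import torch

    B, seq_len, K, vocab_size = 2, 16, 8, 128
    full_student_logits = torch.randn(B, seq_len, vocab_size, device="cuda", requires_grad=True)
    teacher_topk_ids = torch.randint(0, vocab_size, (B, seq_len, K), device="cuda")
    teacher_logprobs = torch.log_softmax(torch.randn(B, seq_len, K, device="cuda"), dim=-1)
    mask = torch.ones(B, seq_len, K, dtype=torch.bool, device="cuda")

    # gather down to [B, seq_len, K] first -- the wrapper expects this shape
    student_logits = full_student_logits.gather(-1, teacher_topk_ids)
    loss = kd_loss_triton(student_logits, teacher_logprobs, mask)
    loss.backward()  # routes through _KLDivergenceTritonFn.backward
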