no log etc
This commit is contained in:
@@ -2,7 +2,6 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Triton Kernel for forward pass
|
||||
# --------------------------------------------------------
|
||||
@@ -30,21 +29,21 @@ import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def kd_forward_kernel(
|
||||
student_logits_ptr, # float32[B, seq_len, K] after gather
|
||||
teacher_logprobs_ptr, # float32[B, seq_len, K]
|
||||
mask_ptr, # bool[B, seq_len, K] or int8
|
||||
partial_kd_ptr, # float32[B, seq_len] (accumulator)
|
||||
B, # total batch size
|
||||
seq_len, # total sequence length from teacher
|
||||
K, # top-K from teacher
|
||||
BLOCK_SIZE: tl.constexpr # how many tokens per block in dimension0
|
||||
student_logits_ptr, # float32[B, seq_len, K] after gather
|
||||
teacher_logprobs_ptr, # float32[B, seq_len, K]
|
||||
mask_ptr, # bool[B, seq_len, K] or int8
|
||||
partial_kd_ptr, # float32[B, seq_len] (accumulator)
|
||||
B, # total batch size
|
||||
seq_len, # total sequence length from teacher
|
||||
K, # top-K from teacher
|
||||
BLOCK_SIZE: tl.constexpr, # how many tokens per block in dimension0
|
||||
):
|
||||
# program_id is the global index for each block
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# Each block handles a range of seq positions in [0..B*seq_len)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_end = tl.min((pid+1)*BLOCK_SIZE, B * seq_len)
|
||||
block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
|
||||
length = block_end - block_start
|
||||
|
||||
# Offsets for indexing
|
||||
@@ -69,7 +68,7 @@ def kd_forward_kernel(
|
||||
|
||||
# For K top tokens, read the relevant student logits and teacher logprobs
|
||||
# We'll load them in a small loop:
|
||||
logsumexp_val = float('-inf')
|
||||
logsumexp_val = float("-inf")
|
||||
# We'll store them in a local array for a second pass
|
||||
student_logits_k = [0.0 for _ in range(K)]
|
||||
teacher_logprobs_k = [0.0 for _ in range(K)]
|
||||
@@ -79,26 +78,17 @@ def kd_forward_kernel(
|
||||
for k in range(K):
|
||||
# load student logit
|
||||
student_val = tl.load(
|
||||
student_logits_ptr
|
||||
+ b_idx*seq_len*K
|
||||
+ s_idx*K
|
||||
+ k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len)
|
||||
student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len),
|
||||
)
|
||||
teacher_val = tl.load(
|
||||
teacher_logprobs_ptr
|
||||
+ b_idx*seq_len*K
|
||||
+ s_idx*K
|
||||
+ k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len)
|
||||
teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len),
|
||||
)
|
||||
# get mask
|
||||
mask_val = tl.load(
|
||||
mask_ptr
|
||||
+ b_idx*seq_len*K
|
||||
+ s_idx*K
|
||||
+ k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len)
|
||||
mask_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||
mask=(b_idx < B) and (s_idx < seq_len),
|
||||
)
|
||||
|
||||
student_logits_k[k] = student_val
|
||||
@@ -119,7 +109,7 @@ def kd_forward_kernel(
|
||||
# safe check
|
||||
if exp_sum == 0.0:
|
||||
exp_sum = 1e-8
|
||||
logsumexp_val = logsumexp_val + float(torch.log(torch.tensor(exp_sum)))
|
||||
logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))
|
||||
|
||||
# compute partial kl
|
||||
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
|
||||
@@ -155,7 +145,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
||||
BLOCK_SIZE = 1024
|
||||
# compute how many blocks we need
|
||||
total_positions = B * seq_len
|
||||
grid = ( (total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE , )
|
||||
grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
|
||||
|
||||
partial_kd = torch.empty(
|
||||
grid[0], dtype=student_logits.dtype, device=student_logits.device
|
||||
@@ -167,8 +157,10 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
||||
teacher_logprobs,
|
||||
mask,
|
||||
partial_kd,
|
||||
B, seq_len, K,
|
||||
BLOCK_SIZE=BLOCK_SIZE
|
||||
B,
|
||||
seq_len,
|
||||
K,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# Sum partials on CPU or GPU
|
||||
@@ -186,7 +178,12 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
||||
# Save context for backward
|
||||
# Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
|
||||
# But be mindful of memory usage. We’ll demonstrate the minimal approach here:
|
||||
ctx.save_for_backward(student_logits, teacher_logprobs, mask, torch.tensor(num_items_in_batch or 0))
|
||||
ctx.save_for_backward(
|
||||
student_logits,
|
||||
teacher_logprobs,
|
||||
mask,
|
||||
torch.tensor(num_items_in_batch or 0),
|
||||
)
|
||||
ctx.B = B
|
||||
ctx.seq_len = seq_len
|
||||
ctx.K = K
|
||||
@@ -238,7 +235,9 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
||||
# treat student_logits as if it requires grad
|
||||
stl = student_logits.clone().detach().requires_grad_(True)
|
||||
# compute logsumexp along K
|
||||
logsumexp_val = torch.logsumexp(stl, dim=-1, keepdim=True) # [B, seq_len, 1]
|
||||
logsumexp_val = torch.logsumexp(
|
||||
stl, dim=-1, keepdim=True
|
||||
) # [B, seq_len, 1]
|
||||
student_logprobs_topk = stl - logsumexp_val
|
||||
teacher_probs = teacher_logprobs.exp()
|
||||
# p^S_k
|
||||
@@ -266,7 +265,7 @@ def kd_loss_triton(
|
||||
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
|
||||
teacher_logprobs,
|
||||
mask,
|
||||
num_items_in_batch=None
|
||||
num_items_in_batch=None,
|
||||
):
|
||||
"""
|
||||
Wrapper that calls our Triton-based forward+backward for KD.
|
||||
@@ -274,4 +273,6 @@ def kd_loss_triton(
|
||||
or inside a separate kernel. This function expects that you've *already*
|
||||
called gather on student_logits -> shape [B, seq_len, K].
|
||||
"""
|
||||
return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch)
|
||||
return _KLDivergenceTritonFn.apply(
|
||||
student_logits, teacher_logprobs, mask, num_items_in_batch
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user