no log etc
This commit is contained in:
@@ -2,7 +2,6 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
# Triton Kernel for forward pass
|
# Triton Kernel for forward pass
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
@@ -30,21 +29,21 @@ import triton.language as tl
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kd_forward_kernel(
|
def kd_forward_kernel(
|
||||||
student_logits_ptr, # float32[B, seq_len, K] after gather
|
student_logits_ptr, # float32[B, seq_len, K] after gather
|
||||||
teacher_logprobs_ptr, # float32[B, seq_len, K]
|
teacher_logprobs_ptr, # float32[B, seq_len, K]
|
||||||
mask_ptr, # bool[B, seq_len, K] or int8
|
mask_ptr, # bool[B, seq_len, K] or int8
|
||||||
partial_kd_ptr, # float32[B, seq_len] (accumulator)
|
partial_kd_ptr, # float32[B, seq_len] (accumulator)
|
||||||
B, # total batch size
|
B, # total batch size
|
||||||
seq_len, # total sequence length from teacher
|
seq_len, # total sequence length from teacher
|
||||||
K, # top-K from teacher
|
K, # top-K from teacher
|
||||||
BLOCK_SIZE: tl.constexpr # how many tokens per block in dimension0
|
BLOCK_SIZE: tl.constexpr, # how many tokens per block in dimension0
|
||||||
):
|
):
|
||||||
# program_id is the global index for each block
|
# program_id is the global index for each block
|
||||||
pid = tl.program_id(0)
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
# Each block handles a range of seq positions in [0..B*seq_len)
|
# Each block handles a range of seq positions in [0..B*seq_len)
|
||||||
block_start = pid * BLOCK_SIZE
|
block_start = pid * BLOCK_SIZE
|
||||||
block_end = tl.min((pid+1)*BLOCK_SIZE, B * seq_len)
|
block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
|
||||||
length = block_end - block_start
|
length = block_end - block_start
|
||||||
|
|
||||||
# Offsets for indexing
|
# Offsets for indexing
|
||||||
@@ -69,7 +68,7 @@ def kd_forward_kernel(
|
|||||||
|
|
||||||
# For K top tokens, read the relevant student logits and teacher logprobs
|
# For K top tokens, read the relevant student logits and teacher logprobs
|
||||||
# We'll load them in a small loop:
|
# We'll load them in a small loop:
|
||||||
logsumexp_val = float('-inf')
|
logsumexp_val = float("-inf")
|
||||||
# We'll store them in a local array for a second pass
|
# We'll store them in a local array for a second pass
|
||||||
student_logits_k = [0.0 for _ in range(K)]
|
student_logits_k = [0.0 for _ in range(K)]
|
||||||
teacher_logprobs_k = [0.0 for _ in range(K)]
|
teacher_logprobs_k = [0.0 for _ in range(K)]
|
||||||
@@ -79,26 +78,17 @@ def kd_forward_kernel(
|
|||||||
for k in range(K):
|
for k in range(K):
|
||||||
# load student logit
|
# load student logit
|
||||||
student_val = tl.load(
|
student_val = tl.load(
|
||||||
student_logits_ptr
|
student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||||
+ b_idx*seq_len*K
|
mask=(b_idx < B) and (s_idx < seq_len),
|
||||||
+ s_idx*K
|
|
||||||
+ k,
|
|
||||||
mask=(b_idx < B) and (s_idx < seq_len)
|
|
||||||
)
|
)
|
||||||
teacher_val = tl.load(
|
teacher_val = tl.load(
|
||||||
teacher_logprobs_ptr
|
teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||||
+ b_idx*seq_len*K
|
mask=(b_idx < B) and (s_idx < seq_len),
|
||||||
+ s_idx*K
|
|
||||||
+ k,
|
|
||||||
mask=(b_idx < B) and (s_idx < seq_len)
|
|
||||||
)
|
)
|
||||||
# get mask
|
# get mask
|
||||||
mask_val = tl.load(
|
mask_val = tl.load(
|
||||||
mask_ptr
|
mask_ptr + b_idx * seq_len * K + s_idx * K + k,
|
||||||
+ b_idx*seq_len*K
|
mask=(b_idx < B) and (s_idx < seq_len),
|
||||||
+ s_idx*K
|
|
||||||
+ k,
|
|
||||||
mask=(b_idx < B) and (s_idx < seq_len)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
student_logits_k[k] = student_val
|
student_logits_k[k] = student_val
|
||||||
@@ -119,7 +109,7 @@ def kd_forward_kernel(
|
|||||||
# safe check
|
# safe check
|
||||||
if exp_sum == 0.0:
|
if exp_sum == 0.0:
|
||||||
exp_sum = 1e-8
|
exp_sum = 1e-8
|
||||||
logsumexp_val = logsumexp_val + float(torch.log(torch.tensor(exp_sum)))
|
logsumexp_val = logsumexp_val + tl.log(torch.tensor(exp_sum))
|
||||||
|
|
||||||
# compute partial kl
|
# compute partial kl
|
||||||
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
|
# sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
|
||||||
@@ -155,7 +145,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
# compute how many blocks we need
|
# compute how many blocks we need
|
||||||
total_positions = B * seq_len
|
total_positions = B * seq_len
|
||||||
grid = ( (total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE , )
|
grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
|
||||||
|
|
||||||
partial_kd = torch.empty(
|
partial_kd = torch.empty(
|
||||||
grid[0], dtype=student_logits.dtype, device=student_logits.device
|
grid[0], dtype=student_logits.dtype, device=student_logits.device
|
||||||
@@ -167,8 +157,10 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
teacher_logprobs,
|
teacher_logprobs,
|
||||||
mask,
|
mask,
|
||||||
partial_kd,
|
partial_kd,
|
||||||
B, seq_len, K,
|
B,
|
||||||
BLOCK_SIZE=BLOCK_SIZE
|
seq_len,
|
||||||
|
K,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sum partials on CPU or GPU
|
# Sum partials on CPU or GPU
|
||||||
@@ -186,7 +178,12 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
# Save context for backward
|
# Save context for backward
|
||||||
# Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
|
# Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
|
||||||
# But be mindful of memory usage. We’ll demonstrate the minimal approach here:
|
# But be mindful of memory usage. We’ll demonstrate the minimal approach here:
|
||||||
ctx.save_for_backward(student_logits, teacher_logprobs, mask, torch.tensor(num_items_in_batch or 0))
|
ctx.save_for_backward(
|
||||||
|
student_logits,
|
||||||
|
teacher_logprobs,
|
||||||
|
mask,
|
||||||
|
torch.tensor(num_items_in_batch or 0),
|
||||||
|
)
|
||||||
ctx.B = B
|
ctx.B = B
|
||||||
ctx.seq_len = seq_len
|
ctx.seq_len = seq_len
|
||||||
ctx.K = K
|
ctx.K = K
|
||||||
@@ -238,7 +235,9 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
# treat student_logits as if it requires grad
|
# treat student_logits as if it requires grad
|
||||||
stl = student_logits.clone().detach().requires_grad_(True)
|
stl = student_logits.clone().detach().requires_grad_(True)
|
||||||
# compute logsumexp along K
|
# compute logsumexp along K
|
||||||
logsumexp_val = torch.logsumexp(stl, dim=-1, keepdim=True) # [B, seq_len, 1]
|
logsumexp_val = torch.logsumexp(
|
||||||
|
stl, dim=-1, keepdim=True
|
||||||
|
) # [B, seq_len, 1]
|
||||||
student_logprobs_topk = stl - logsumexp_val
|
student_logprobs_topk = stl - logsumexp_val
|
||||||
teacher_probs = teacher_logprobs.exp()
|
teacher_probs = teacher_logprobs.exp()
|
||||||
# p^S_k
|
# p^S_k
|
||||||
@@ -266,7 +265,7 @@ def kd_loss_triton(
|
|||||||
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
|
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
|
||||||
teacher_logprobs,
|
teacher_logprobs,
|
||||||
mask,
|
mask,
|
||||||
num_items_in_batch=None
|
num_items_in_batch=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Wrapper that calls our Triton-based forward+backward for KD.
|
Wrapper that calls our Triton-based forward+backward for KD.
|
||||||
@@ -274,4 +273,6 @@ def kd_loss_triton(
|
|||||||
or inside a separate kernel. This function expects that you've *already*
|
or inside a separate kernel. This function expects that you've *already*
|
||||||
called gather on student_logits -> shape [B, seq_len, K].
|
called gather on student_logits -> shape [B, seq_len, K].
|
||||||
"""
|
"""
|
||||||
return _KLDivergenceTritonFn.apply(student_logits, teacher_logprobs, mask, num_items_in_batch)
|
return _KLDivergenceTritonFn.apply(
|
||||||
|
student_logits, teacher_logprobs, mask, num_items_in_batch
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user