v3
@@ -108,13 +108,21 @@ class AxolotlKDTrainer(AxolotlTrainer):
         )

         # Now call the Triton-based KD loss
-        loss_kd = kd_loss_triton(
+        kd_sum = kd_loss_triton(
             student_logits_topk,
             target_logprobs,  # teacher logprobs [B, seq_len, K]
             target_mask,  # mask [B, seq_len, K]
-            num_items_in_batch=num_items_in_batch,
         )

+        # Normalize however you want
+        if num_items_in_batch is not None:
+            loss_kd = kd_sum / num_items_in_batch
+        else:
+            # or do e.g. average over valid tokens
+            # quick example:
+            total_valid = target_mask.sum()
+            loss_kd = kd_sum / (total_valid + 1e-8)
+
         # optionally combine with CE loss
         if self.args.kd_ce_alpha > 0:
             loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
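Note: the call above assumes the student logits were already gathered at the teacher's
top-K token ids, giving `student_logits_topk` of shape [B, seq_len, K]. A minimal,
illustrative sketch of that gather (the tensor names and sizes here are placeholders,
not part of this commit):

    import torch

    B, seq_len, V, K = 2, 4, 32000, 8
    student_logits_full = torch.randn(B, seq_len, V)          # raw student logits over the full vocab
    target_token_ids = torch.randint(0, V, (B, seq_len, K))   # teacher top-K token ids
    # pick the student logit at each teacher top-K id -> [B, seq_len, K],
    # aligned with target_logprobs / target_mask
    student_logits_topk = torch.gather(student_logits_full, dim=-1, index=target_token_ids)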
@@ -29,236 +29,215 @@ import triton.language as tl

 @triton.jit
 def kd_forward_kernel(
-    student_logits_ptr,  # float32[B, seq_len, K] after gather
-    teacher_logprobs_ptr,  # float32[B, seq_len, K]
-    mask_ptr,  # bool[B, seq_len, K] or int8
-    partial_kd_ptr,  # float32[B, seq_len] (accumulator)
-    B,  # total batch size
-    seq_len,  # total sequence length from teacher
-    K,  # top-K from teacher
-    BLOCK_SIZE: tl.constexpr,  # how many tokens per block in dimension0
+    # student_logits after gather: [B, seq_len, K] flattened to 1D in row-major
+    student_logits_ptr: tl.tensor,
+    # teacher_logprobs: [B, seq_len, K] flattened
+    teacher_logprobs_ptr: tl.tensor,
+    # mask: [B, seq_len, K] flattened (bool or 0/1)
+    mask_ptr: tl.tensor,
+    # partial_kd: [B*seq_len] flattened buffer to store partial sums
+    partial_kd_ptr: tl.tensor,
+    B: tl.int32,
+    seq_len: tl.int32,
+    K: tl.int32,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    # program_id is the global index for each block
+    """
+    For each position in [0..B*seq_len), we:
+      - gather the K student logits
+      - compute logsumexp
+      - compute the KL sum = sum_{k} t_prob_k * ( t_log_k - s_logprob_k )
+      - store that partial sum into partial_kd_ptr[offset].
+    """
+    # 1) Identify which [B*seq_len] index this block handles
     pid = tl.program_id(0)

-    # Each block handles a range of seq positions in [0..B*seq_len)
-    block_start = pid * BLOCK_SIZE
-    block_end = tl.min((pid + 1) * BLOCK_SIZE, B * seq_len)
-    length = block_end - block_start
+    # 2) Vector of [0..BLOCK_SIZE) local offsets
+    offsets = tl.arange(0, BLOCK_SIZE)
+    # 3) Global indices = pid * BLOCK_SIZE + offsets
+    idx = pid * BLOCK_SIZE + offsets

-    # Offsets for indexing
-    # We want to interpret a linear index in [0..B*seq_len) as (batch_idx, seq_idx)
-    # E.g.:
-    #   batch_idx = block_start // seq_len
-    #   seq_idx = block_start % seq_len
-    # but we must do this for each element in the block. We'll do that inside a loop.
+    # 4) Mask to ensure we don't read out-of-bounds
+    total_positions = B * seq_len
+    mask_pos = idx < total_positions

-    # We'll store a running partial KL sum in registers
-    # We do a for-loop for each position in the block, then do a thread-level reduction
-    kd_reg = 0.0
+    # 5) Convert a 1D `idx` => (b_idx, s_idx)
+    #    b_idx is the batch number, s_idx is the sequence position
+    b_idx = idx // seq_len
+    s_idx = idx % seq_len

-    # We'll iterate over each item in [block_start, block_end).
-    # A more advanced approach can use vectorization / warp-based parallelism inside the block.
-    for offset in range(length):
-        # Convert offset -> actual index in [0..B*seq_len)
-        linear_idx = block_start + offset
-        # batch index and sequence index
-        b_idx = linear_idx // seq_len
-        s_idx = linear_idx % seq_len
+    # We'll accumulate the KL for each index in a register array
+    kl_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

-        # For K top tokens, read the relevant student logits and teacher logprobs
-        # We'll load them in a small loop:
-        logsumexp_val = float("-inf")
-        # We'll store them in a local array for a second pass
-        student_logits_k = [0.0 for _ in range(K)]
-        teacher_logprobs_k = [0.0 for _ in range(K)]
-        valid_k = [0 for _ in range(K)]
+    # -------------------------------------------------------------------------
+    # First pass: find max logits over K to implement logsumexp
+    # -------------------------------------------------------------------------
+    max_val = tl.full([BLOCK_SIZE], -1e30, dtype=tl.float32)

-        # gather the top-K logits & teacher logprobs
-        for k in range(K):
-            # load student logit
-            student_val = tl.load(
-                student_logits_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
-            teacher_val = tl.load(
-                teacher_logprobs_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
-            # get mask
-            mask_val = tl.load(
-                mask_ptr + b_idx * seq_len * K + s_idx * K + k,
-                mask=(b_idx < B) and (s_idx < seq_len),
-            )
+    # Python-level loops are allowed in Triton as long as the
+    # operations inside are Triton ops, not torch or Python math.
+    for k in range(K):
+        # pointer offset in the flattened [B, seq_len, K] = b_idx*(seq_len*K) + s_idx*K + k
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k

-            student_logits_k[k] = student_val
-            teacher_logprobs_k[k] = teacher_val
-            valid_k[k] = mask_val
+        # load student logits, masked out-of-bounds with a large negative
+        # so they don't affect the max
+        student_val = tl.where(
+            mask_pos,
+            tl.load(student_logits_ptr + offset_k),
+            -1e30
+        )
+        # update running max
+        max_val = tl.where(student_val > max_val, student_val, max_val)

-            # track max for logsumexp (naive approach)
-            if student_val > logsumexp_val:
-                logsumexp_val = student_val
+    # -------------------------------------------------------------------------
+    # Second pass: sum of exp(...) to complete logsumexp
+    # -------------------------------------------------------------------------
+    exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for k in range(K):
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k
+        student_val = tl.where(
+            mask_pos,
+            tl.load(student_logits_ptr + offset_k),
+            -1e30
+        )
+        # exponent
+        exponent = tl.exp(student_val - max_val)
+        exp_sum += exponent

-        # now compute logsumexp for the K student logits
-        # logsumexp = max_val + log(sum( exp(student_val - max_val) ))
-        exp_sum = 0.0
-        for k in range(K):
-            if valid_k[k] != 0:  # if valid
-                exp_val = tl.exp(student_logits_k[k] - logsumexp_val)
-                exp_sum += exp_val
-        # safe check
-        epsilon = 1e-8  # Small constant to prevent log(0)
-        exp_sum = tl.where(exp_sum == 0.0, epsilon, exp_sum)
-        logsumexp_val = logsumexp_val + tl.log(exp_sum)
+    # final logsumexp
+    logsumexp_val = max_val + tl.log(exp_sum)

-        # compute partial kl
-        # sum_{k in valid} p^T_k (log p^T_k - log p^S_k)
-        # teacher_probs_k = exp(teacher_logprobs_k)
-        for k in range(K):
-            if valid_k[k] != 0:  # only valid tokens
-                teacher_prob = tl.exp(teacher_logprobs_k[k])
-                student_logprob = student_logits_k[k] - logsumexp_val
-                kd_val = teacher_prob * (teacher_logprobs_k[k] - student_logprob)
-                kd_reg += kd_val
+    # -------------------------------------------------------------------------
+    # Third pass: compute partial KL per position
+    #   KL = sum_{k in valid} p^T_k * (teacher_logprobs_k - student_logprobs_k)
+    #
+    #   - teacher_logprobs_k => t_log
+    #   - teacher_prob_k = exp(t_log)
+    #   - student_logprobs_k = s_val - logsumexp_val
+    # -------------------------------------------------------------------------
+    for k in range(K):
+        offset_k = b_idx * (seq_len * K) + s_idx * K + k
+        # teacher logprobs
+        t_log = tl.where(
+            mask_pos,
+            tl.load(teacher_logprobs_ptr + offset_k),
+            -1e30
+        )
+        # teacher prob
+        t_prob = tl.exp(t_log)

-    # Write out partial kd for this block. We store a single partial sum in partial_kd_ptr
-    # We'll store it at partial_kd_ptr[pid]
-    # In real code, you might do an atomic add into partial_kd_ptr or a parallel reduction pass
-    # for now, let's just store it at index=pid
-    tl.store(partial_kd_ptr + pid, kd_reg)
+        # student logit
+        s_val = tl.where(
+            mask_pos,
+            tl.load(student_logits_ptr + offset_k),
+            -1e30
+        )
+        # student logprob
+        s_logprob = s_val - logsumexp_val
+
+        # local KL
+        kl_val = t_prob * (t_log - s_logprob)
+
+        # also read mask to disable invalid tokens if mask is not purely sequence-based
+        valid_k = tl.load(mask_ptr + offset_k)
+        # if mask is bool => use 'valid_k != 0', if it's 0/1 => same
+        is_valid = (valid_k > 0)
+
+        # zero out if either this index is out-of-bounds or mask is invalid
+        kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
+
+        # accumulate
+        kl_acc += kl_val
+
+    # -------------------------------------------------------------------------
+    # Store the partial KL in partial_kd_ptr for each element in idx.
+    # Later in Python, you can do partial_kd.sum() to get the total KL.
+    # -------------------------------------------------------------------------
+    tl.store(partial_kd_ptr + idx, kl_acc, mask=mask_pos)


+def kd_forward_pass_triton(
+    student_logits,  # [B, seq_len, K] (already gathered)
+    teacher_logprobs,  # [B, seq_len, K]
+    mask,  # [B, seq_len, K] bool or 0/1
+    BLOCK_SIZE=1024,
+):
+    """
+    Returns total KL (float). We do the sum on the Python side.
+    NOTE: No normalization is done here.
+    You might divide by `num_items_in_batch` or # valid tokens afterward.
+    """
+    B, seq_len, K = student_logits.shape
+    # Flatten
+    student_logits_flat = student_logits.reshape(-1)
+    teacher_logprobs_flat = teacher_logprobs.reshape(-1)
+    mask_flat = mask.reshape(-1)
+
+    total_positions = B * seq_len
+    # We'll store partial KL sums for each of the B*seq_len positions
+    partial_kd = torch.empty(
+        total_positions, dtype=student_logits.dtype, device=student_logits.device
+    )
+
+    # Grid config
+    grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+    kd_forward_kernel[grid](
+        student_logits_flat,
+        teacher_logprobs_flat,
+        mask_flat,
+        partial_kd,
+        B, seq_len, K,
+        BLOCK_SIZE=BLOCK_SIZE
+    )
+
+    # Sum on CPU or GPU
+    kd_sum = partial_kd.sum()
+    return kd_sum

 class _KLDivergenceTritonFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, student_logits, teacher_logprobs, mask, num_items_in_batch):
+    def forward(ctx, student_logits, teacher_logprobs, mask):
         """
-        Inputs shape assumptions (after gather!):
-          - student_logits: [B, seq_len, K]
-          - teacher_logprobs: [B, seq_len, K]
-          - mask: [B, seq_len, K] (bool or 0/1) for valid tokens
+        student_logits: (B, seq_len, K)
+        teacher_logprobs: (B, seq_len, K)
+        mask: (B, seq_len, K)
         """
-        B, seq_len, K = student_logits.shape

-        # Prepare output buffer for partial sums
-        # We'll have BLOCK_SIZE define how many (batch*seq_len) items each block processes
-        # For simplicity, let's aim for one block per 1024 positions
-        BLOCK_SIZE = 1024
-        # compute how many blocks we need
-        total_positions = B * seq_len
-        grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
-
-        partial_kd = torch.empty(
-            grid[0], dtype=student_logits.dtype, device=student_logits.device
-        )
-
-        # Launch kernel
-        kd_forward_kernel[grid](
-            student_logits,
-            teacher_logprobs,
-            mask,
-            partial_kd,
-            B,
-            seq_len,
-            K,
-            BLOCK_SIZE=BLOCK_SIZE,
-        )
-
-        # Sum partials on CPU or GPU
-        kd_sum = partial_kd.sum()
-
-        # normalize
-        if num_items_in_batch is not None:
-            kd_loss = kd_sum / num_items_in_batch
-        else:
-            # Just average over all valid tokens; in practice you'd need the count of valid tokens
-            # For a quick approximation, let's do kd_sum / total_positions (or do a separate reduction on mask)
-            # This is a simplification. For correctness, you should count valid tokens in the kernel.
-            kd_loss = kd_sum / (total_positions * K)
-
-        # Save context for backward
-        # Typically, you'd need to save the raw student_logits, teacher_logprobs, etc. for grad
-        # But be mindful of memory usage. We'll demonstrate the minimal approach here:
-        ctx.save_for_backward(
-            student_logits,
-            teacher_logprobs,
-            mask,
-            torch.tensor(num_items_in_batch or 0),
-        )
-        ctx.B = B
-        ctx.seq_len = seq_len
-        ctx.K = K
-        ctx.total_positions = total_positions
-        ctx.BLOCK_SIZE = BLOCK_SIZE
+        kd_sum = kd_forward_pass_triton(student_logits, teacher_logprobs, mask)
+        kd_loss = kd_sum  # Not normalized here. You can do that externally.
+
+        # Save for backward
+        ctx.save_for_backward(student_logits, teacher_logprobs, mask)
         return kd_loss

     @staticmethod
     def backward(ctx, grad_output):
"""
|
||||
grad_output is dLoss/dOut (a scalar).
|
||||
We want dLoss/dStudentLogits.
|
||||
Recall that:
|
||||
|
||||
Loss = sum_{valid k} p^T_k ( log p^T_k - (student_logits_k - logsumexp(student_logits_all_k)) )
|
||||
= sum_{valid k} p^T_k log p^T_k - sum_{valid k} p^T_k student_logits_k + sum_{valid k} p^T_k logsumexp(...)
|
||||
|
||||
Let’s break down the derivative wrt student_logits_k. More precisely, from:
|
||||
d/d student_logits_k [ - p^T_k student_logprobs_k ]
|
||||
you get:
|
||||
- p^T_k * ( d/d student_logits_k [ student_logits_k - logsumexp(...) ] )
|
||||
= - p^T_k * (1 - p^S_k)
|
||||
= p^T_k * p^S_k - p^T_k
|
||||
= p^S_k * p^T_k - p^T_k
|
||||
= p^T_k( p^S_k - 1 )
|
||||
|
||||
In practice, we also must handle the mask.
|
||||
A real implementation typically re-runs the gather & logsumexp calculations or caches them in forward().
|
||||
For brevity, we do a naive approach in PyTorch (not Triton) for the backward.
|
||||
For maximum speed, you'd do a second Triton kernel.
|
||||
|
||||
We'll do a minimal approach here: recompute everything on the host side or a pure PyTorch pass.
|
||||
"""
|
||||
-        student_logits, teacher_logprobs, mask, num_items_in_batch_t = ctx.saved_tensors
-        num_items_in_batch = int(num_items_in_batch_t.item())
-        B, seq_len, K = ctx.B, ctx.seq_len, ctx.K
-
-        # We can either replicate the entire forward logic in PyTorch for gradient
-        # or do a second Triton pass. Here, let's do it in PyTorch for clarity.
-
-        # 1) compute logsumexp of student_logits_k for each [b, s]
-        # 2) compute p^S_k
-        # 3) compute p^T_k from teacher_logprobs
-        # 4) dLoss/dStudentLogits = grad_output * p^T_k ( p^S_k - 1 ), masked
-        # 5) sum or gather the final gradient
+        # We'll do naive PyTorch re-computation for gradient wrt student_logits
+        student_logits, teacher_logprobs, mask = ctx.saved_tensors
+        # grad_output is dLoss/dOut => a scalar
+        # Let's compute dLoss/dStudentLogits with the same formula as your original code

         with torch.enable_grad():
             # treat student_logits as if it requires grad
             stl = student_logits.clone().detach().requires_grad_(True)
-            # compute logsumexp along K
-            logsumexp_val = torch.logsumexp(
-                stl, dim=-1, keepdim=True
-            )  # [B, seq_len, 1]
-            student_logprobs_topk = stl - logsumexp_val
-            teacher_probs = teacher_logprobs.exp()
-            # p^S_k
-            p_s = student_logprobs_topk.exp()
+            t_log = teacher_logprobs
+            # mask might be bool or 0/1
+            # compute logsumexp
+            lse = torch.logsumexp(stl, dim=-1, keepdim=True)
+            s_logprob = stl - lse
+            t_prob = t_log.exp()

-            # forward kl = sum p^T_k ( teacher_logprobs_k - student_logprobs_topk )
-            # derivative wrt stl = p^T_k( p^S_k - 1 )
-            grad_stl = teacher_probs * (p_s - 1.0)
-            # respect the mask
-            grad_stl = grad_stl * mask  # zero out invalid
+            # forward KL = sum_{k} p^T_k ( t_log_k - s_logprob_k )
+            kl_val = t_prob * (t_log - s_logprob)
+            # mask out
+            kl_val = kl_val * mask  # zero out invalid

-            # sum or average
-            if num_items_in_batch != 0:
-                grad_stl = grad_stl / num_items_in_batch
-            else:
-                grad_stl = grad_stl / (B * seq_len * K)  # fallback
+            kd_loss = kl_val.sum()
+            # now compute dLoss/d stl
+            grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]

-            # multiply by upstream grad_output
-            grad_stl = grad_stl * grad_output
-
-        return grad_stl, None, None, None
+        return grad_stl, None, None


 def kd_loss_triton(
@@ -274,5 +253,5 @@ def kd_loss_triton(
     called gather on student_logits -> shape [B, seq_len, K].
     """
     return _KLDivergenceTritonFn.apply(
-        student_logits, teacher_logprobs, mask, num_items_in_batch
+        student_logits, teacher_logprobs, mask,  # num_items_in_batch
     )
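Usage sketch for the Triton path end to end (illustrative only: the module name, the
random tensors, and the normalization choice are placeholders, not part of this commit,
and a CUDA device with Triton installed is assumed):

    import torch
    from kd_kernels import kd_loss_triton  # hypothetical module name for the file above

    B, seq_len, K = 2, 16, 8
    student_logits_topk = torch.randn(B, seq_len, K, device="cuda")
    # placeholder teacher top-K logprobs; in practice these come from the teacher model
    teacher_logprobs = torch.log_softmax(torch.randn(B, seq_len, K, device="cuda"), dim=-1)
    target_mask = torch.ones(B, seq_len, K, device="cuda")

    kd_sum = kd_loss_triton(student_logits_topk, teacher_logprobs, target_mask)
    # the Function returns an un-normalized sum; normalize as in the trainer hunk above
    loss_kd = kd_sum / target_mask.sum().clamp_min(1.0)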
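For parity checks, the same masked top-K forward KL can be written in a few lines of
plain PyTorch. This is a reference sketch mirroring the recomputation already done in
backward() above, not part of this commit:

    import torch

    def kd_loss_reference(student_logits, teacher_logprobs, mask):
        # student_logits, teacher_logprobs, mask: [B, seq_len, K]
        s_logprob = student_logits - torch.logsumexp(student_logits, dim=-1, keepdim=True)
        t_prob = teacher_logprobs.exp()
        kl = t_prob * (teacher_logprobs - s_logprob)
        # zero out invalid entries, then sum: the same un-normalized total the kernel stores
        return (kl * mask).sum()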