triton kernel for top-k logprob kd

Wing Lian
2025-02-24 22:13:26 -05:00
parent 2d5826f544
commit 165088e7c1


@@ -0,0 +1,281 @@
"""
Optimized Triton kernel for the KL divergence loss between a teacher's top-k logprobs and a student model.
"""
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.benchmark import Timer
# Helper Triton kernel for computing logsumexp over the vocabulary dimension
@triton.jit
def logsumexp_kernel(
logits_ptr, output_ptr,
B, S, V, # batch size, seq len, vocab size
stride_b, stride_s, stride_v,
out_stride_b, out_stride_s,
BLOCK_SIZE: tl.constexpr
):
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Pointers
logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
# Find maximum for numerical stability
max_val = -float('inf')
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask, other=-float('inf'))
max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
# Compute sum of exp(logit - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask, other=-float('inf'))
sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
# Compute logsumexp
result = max_val + tl.log(sum_exp)
# Store result
tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
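# The kernel above computes logsumexp(logits[b, s, :]) with one program per (batch, seq)
# position, using the standard two-pass max / sum-exp trick for numerical stability.
# Minimal sanity-check sketch (illustrative only, not used elsewhere in this file; assumes a
# CUDA device, and the helper name `_check_logsumexp_kernel` is hypothetical):
def _check_logsumexp_kernel(B=2, S=4, V=32000):
    logits = torch.randn(B, S, V, device="cuda", dtype=torch.float32).contiguous()
    out = torch.empty((B, S), device="cuda", dtype=torch.float32)
    logsumexp_kernel[(B * S,)](
        logits, out,
        B, S, V,
        logits.stride(0), logits.stride(1), logits.stride(2),
        out.stride(0), out.stride(1),
        min(1024, triton.next_power_of_2(V)),
    )
    # Compare against the PyTorch reference implementation
    return torch.allclose(out, torch.logsumexp(logits, dim=-1), atol=1e-4)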
# Triton-accelerated implementation of KL divergence loss for top-k tokens
class TopKKLDivergence(torch.autograd.Function):
@staticmethod
def forward(ctx, student_logits, target_token_ids, target_logprobs, target_mask,
num_items_in_batch=-1, kd_temperature=1.0, top_k_before_softmax=0):
"""
Forward pass for KL divergence loss between top-k logprobs.
"""
# Convert inputs to appropriate types
student_logits = student_logits.float()
target_logprobs = target_logprobs.float()
# Get dimensions
batch_size, student_seq_len, vocab_size = student_logits.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Slice student logits to match teacher sequence length
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
if top_k_before_softmax:
# 1. Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# 2. Gather student logits for top-k tokens
student_logits_topk = torch.gather(student_logits_for_kd, dim=-1, index=target_token_ids)
# 3. Compute log-softmax over only the gathered top-k logits
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
else:
# 1. Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# 2. Gather student logits for top-k tokens
student_logits_topk = torch.gather(student_logits_for_kd, dim=-1, index=target_token_ids)
# 3. Compute logsumexp over full vocabulary using Triton
# Make the sliced logits contiguous once and launch the kernel with that tensor's strides,
# so the pointer and the strides passed to the kernel always refer to the same layout
student_logits_contig = student_logits_for_kd.contiguous()
student_lse = torch.empty((batch_size, teacher_seq_len),
                          dtype=torch.float32, device=student_logits.device)
grid = (batch_size * teacher_seq_len,)
logsumexp_kernel[grid](
    student_logits_contig, student_lse,
    batch_size, teacher_seq_len, vocab_size,
    student_logits_contig.stride(0), student_logits_contig.stride(1), student_logits_contig.stride(2),
    student_lse.stride(0), student_lse.stride(1),
    min(1024, triton.next_power_of_2(vocab_size))
)
# 4. Convert to logprobs
student_logprobs_topk = student_logits_topk - student_lse.unsqueeze(-1)
# Save tensors for backward pass
ctx.save_for_backward(student_logits_for_kd, target_token_ids, target_logprobs, target_mask, student_logprobs_topk)
ctx.kd_temperature = kd_temperature
ctx.top_k_before_softmax = top_k_before_softmax
ctx.num_items_in_batch = num_items_in_batch
ctx.student_seq_len = student_seq_len
# Convert mask to boolean
valid_mask = target_mask.bool()
# Extract valid tokens only
student_logprobs_valid = student_logprobs_topk[valid_mask]
target_logprobs_valid = target_logprobs[valid_mask]
# Convert teacher logprobs to probabilities
teacher_probs_valid = torch.exp(target_logprobs_valid)
# Compute KL divergence loss
token_losses = teacher_probs_valid * (target_logprobs_valid - student_logprobs_valid)
kd_loss = token_losses.sum()
# Apply temperature scaling
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(token_losses.size(0))
return kd_loss
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass for KL divergence loss.
"""
student_logits, target_token_ids, target_logprobs, target_mask, student_logprobs = ctx.saved_tensors
kd_temperature = ctx.kd_temperature
top_k_before_softmax = ctx.top_k_before_softmax
num_items_in_batch = ctx.num_items_in_batch
valid_mask = target_mask.bool()
batch_size, seq_len, vocab_size = student_logits.shape
# Scale gradient by temperature if needed
scale = (kd_temperature**2) if kd_temperature != 1.0 else 1.0
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
scale = scale / float(num_items_in_batch)
else:
scale = scale / float(valid_mask.sum().item())
# Fold in the upstream gradient (the loss is a scalar, so grad_output is a scalar)
scale = scale * grad_output.item()
# Let PyTorch compute the gradients for us
with torch.enable_grad():
student_logits_grad = student_logits.detach().requires_grad_(True)
if top_k_before_softmax:
student_logits_topk = torch.gather(
student_logits_grad / kd_temperature if kd_temperature != 1.0 else student_logits_grad,
dim=-1, index=target_token_ids
)
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
else:
temp_logits = student_logits_grad / kd_temperature if kd_temperature != 1.0 else student_logits_grad
student_logits_topk = torch.gather(temp_logits, dim=-1, index=target_token_ids)
student_lse = torch.logsumexp(temp_logits, dim=-1, keepdim=True)
student_logprobs_topk = student_logits_topk - student_lse
# Extract valid tokens only
student_logprobs_valid = student_logprobs_topk[valid_mask]
target_logprobs_valid = target_logprobs[valid_mask]
teacher_probs_valid = torch.exp(target_logprobs_valid)
# Compute KL divergence loss
token_losses = teacher_probs_valid * (target_logprobs_valid - student_logprobs_valid)
kd_loss = token_losses.sum() * scale
# Backward pass
kd_loss.backward()
grad_student_logits = student_logits_grad.grad
# Pad back to the student's full sequence length if the logits were sliced in forward
if grad_student_logits.shape[1] != ctx.student_seq_len:
    padded = torch.zeros(batch_size, ctx.student_seq_len, vocab_size,
                         dtype=grad_student_logits.dtype, device=grad_student_logits.device)
    padded[:, :seq_len, :] = grad_student_logits
    grad_student_logits = padded
return grad_student_logits, None, None, None, None, None, None
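# For reference, the loss computed by TopKKLDivergence is the forward KL restricted to the
# teacher's top-k support: the sum over valid (b, s, k) of p_T * (log p_T - log p_S), scaled by
# kd_temperature**2 and normalized by num_items_in_batch (or by the number of valid entries).
# A plain-PyTorch sketch of the same computation (top_k_before_softmax=0 path, auto
# normalization), useful for checking the custom autograd function on small inputs; the helper
# name is illustrative and not used elsewhere in this file:
def _reference_topk_kd_loss(student_logits, target_token_ids, target_logprobs, target_mask,
                            kd_temperature=1.0):
    teacher_seq_len = target_token_ids.shape[1]
    logits = student_logits[:, :teacher_seq_len, :].float() / kd_temperature
    # log-softmax over the full vocab, then gather the teacher's top-k token positions
    student_logprobs_topk = torch.gather(torch.log_softmax(logits, dim=-1),
                                         dim=-1, index=target_token_ids)
    mask = target_mask.bool()
    teacher_lp = target_logprobs.float()[mask]
    student_lp = student_logprobs_topk[mask]
    kd_loss = (teacher_lp.exp() * (teacher_lp - student_lp)).sum() * kd_temperature ** 2
    return kd_loss / mask.sum()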
# Wrapper function for chunked computation
def kl_div_loss_chunked(student_logits, target_token_ids, target_logprobs, target_mask,
num_items_in_batch=-1, kd_temperature=1.0, top_k_before_softmax=0,
n_chunks=1):
"""
Memory-efficient KL divergence loss computation.
Args:
student_logits: Student logits [B, seq_len, vocab_size]
target_token_ids: Teacher token IDs [B, seq_len, top_k]
target_logprobs: Teacher logprobs [B, seq_len, top_k]
target_mask: Token mask [B, seq_len, top_k]
num_items_in_batch: Number of items for normalization (-1 for auto)
kd_temperature: Temperature for KD
top_k_before_softmax: Flag for softmax application order
n_chunks: Number of chunks to process (for memory efficiency)
"""
batch_size = student_logits.shape[0]
# If n_chunks <= 0, fall back to one chunk per sample
if n_chunks <= 0:
    n_chunks = batch_size
# Pick the smallest factor of batch_size that is >= n_chunks (or batch_size itself),
# so the batch splits into equally sized chunks
factors = [i for i in range(1, batch_size + 1) if batch_size % i == 0]
actual_chunks = next((f for f in factors if f >= n_chunks), factors[-1])
# Compute chunk size
chunk_size = batch_size // actual_chunks
total_loss = 0.0
# Process in chunks
for i in range(0, batch_size, chunk_size):
chunk_end = min(i + chunk_size, batch_size)
chunk_loss = TopKKLDivergence.apply(
student_logits[i:chunk_end],
target_token_ids[i:chunk_end],
target_logprobs[i:chunk_end],
target_mask[i:chunk_end],
-1 if num_items_in_batch <= 0 else num_items_in_batch // actual_chunks,
kd_temperature,
top_k_before_softmax
)
total_loss += chunk_loss
# Normalize by the number of chunks
return total_loss / actual_chunks
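# Example of the chunking rule above (illustrative): with batch_size=8 and n_chunks=3, the
# factors of 8 are [1, 2, 4, 8]; the smallest factor >= 3 is 4, so the batch is processed as
# 4 chunks of 2 sequences, each passed through TopKKLDivergence.apply independently.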
def loss(
student_logits: torch.Tensor,
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
num_items_in_batch: int = -1,
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
n_chunks: int = 1
) -> torch.Tensor:
"""
Triton-accelerated KL divergence loss for knowledge distillation.
Args:
student_logits: Student model logits [B, seq_len, vocab_size]
target_token_ids: Teacher's top-k token IDs [B, seq_len, top_k]
target_logprobs: Teacher's top-k logprobs [B, seq_len, top_k]
target_mask: Mask for valid tokens [B, seq_len, top_k]
num_items_in_batch: Number of items for normalization (-1 for auto)
kd_temperature: Temperature for KD
top_k_before_softmax: Flag for softmax application order
n_chunks: Number of chunks for memory efficiency
"""
return kl_div_loss_chunked(
student_logits, target_token_ids, target_logprobs, target_mask,
num_items_in_batch, kd_temperature, top_k_before_softmax,
n_chunks
)
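# Minimal usage sketch (illustrative shapes and names; assumes a CUDA device and that the
# teacher's top-k dump provides token ids, logprobs, and a validity mask of shape
# [B, seq_len, top_k]):
def _example_usage():
    B, S, V, K = 2, 16, 32000, 64
    student_logits = torch.randn(B, S, V, device="cuda", requires_grad=True)
    target_token_ids = torch.randint(0, V, (B, S, K), device="cuda")
    target_logprobs = torch.log_softmax(torch.randn(B, S, K, device="cuda"), dim=-1)
    target_mask = torch.ones(B, S, K, device="cuda", dtype=torch.long)
    kd_loss = loss(student_logits, target_token_ids, target_logprobs, target_mask,
                   kd_temperature=1.0, n_chunks=2)
    kd_loss.backward()
    return kd_loss.item(), student_logits.grad.shape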