more optims

This commit is contained in:
Wing Lian
2025-02-26 01:49:47 -05:00
parent d753ead033
commit afbb44f08b


@@ -6,7 +6,125 @@ import torch
import triton
import triton.language as tl
from .logsumexp import logsumexp_kernel
@triton.jit
def fused_logsumexp_logprobs_kernel(
student_logits_ptr, # Input logits in original dtype
student_logprobs_ptr, # Output logprobs (float32)
token_ids_ptr, # Token IDs for top-k
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
temperature,
stride_l_b,
stride_l_s,
stride_l_v,
stride_lp_b,
stride_lp_s,
stride_lp_k,
stride_t_b,
stride_t_s,
stride_t_k,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused kernel that computes logsumexp and logprobs for topk tokens.
All computations are done in float32 for numerical stability.
"""
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Compute logsumexp over the vocabulary
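# Two-pass, numerically stable form: logsumexp(x) = max(x) + log(sum(exp(x - max(x))))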
max_val = -float("inf")
# Phase 1: Find max value across vocabulary
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and upcast to float32
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Update max value
block_max = tl.max(block_logits, axis=0)
max_val = tl.maximum(max_val, block_max)
# Phase 2: Compute sum of exp(logits - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and upcast to float32
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Compute exp(logits - max_val) and add to sum
block_exp = tl.exp(block_logits - max_val)
sum_exp += tl.sum(block_exp * mask, axis=0)
# Compute final logsumexp
logsumexp = max_val + tl.log(sum_exp)
# Phase 3: Compute and store logprobs for the top-k tokens
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
logprobs_base = (
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
)
for k in range(K):
# Load token ID for position k
token_id = tl.load(token_ids_base + k * stride_t_k)
# Load the corresponding logit and convert to float32
token_logit_ptr = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ token_id * stride_l_v
)
token_logit = tl.load(token_logit_ptr).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
token_logit = token_logit / temperature
# Compute logprob directly: logit - logsumexp
token_logprob = token_logit - logsumexp
# Store the result
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
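# Reference sketch (illustrative only, not part of this commit): an eager
# PyTorch equivalent of fused_logsumexp_logprobs_kernel, handy for
# sanity-checking the Triton output. Assumes logits of shape [B, S, V] and
# token_ids of shape [B, S, K]; the helper name is hypothetical.
def _reference_topk_logprobs(logits, token_ids, temperature=1.0):
    # Upcast and apply temperature, mirroring the kernel's float32 math
    scaled = logits.float() / temperature
    # Stable logsumexp over the vocabulary dimension
    lse = torch.logsumexp(scaled, dim=-1, keepdim=True)  # [B, S, 1]
    # Gather the scaled logits of the teacher-selected top-k tokens
    token_logits = torch.gather(scaled, dim=-1, index=token_ids)  # [B, S, K]
    # logprob(t) = logit(t)/T - logsumexp(logits/T), matching the kernel output
    return token_logits - lse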
@triton.jit
@@ -218,64 +336,61 @@ class TopKKLDivergence(torch.autograd.Function):
"""
Forward pass for KL divergence loss between top-k logprobs.
"""
# Convert inputs to appropriate types
student_logits = student_logits.float()
# Only convert target_logprobs to float, leave student_logits as is
target_logprobs = target_logprobs.float()
# Get dimensions
batch_size, _, vocab_size = student_logits.shape
_, teacher_seq_len, _ = target_token_ids.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Slice student logits to match teacher sequence length
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
if top_k_before_softmax:
# 1. Apply temperature to student logits
# Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# 2. Gather student logits for top-k tokens
# Gather student logits for top-k tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
)
# 3. Compute softmax over gathered logits
# Compute softmax over gathered logits
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
student_probs_topk = torch.exp(student_logprobs_topk)
else:
# 1. Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# 2. Gather student logits for top-k tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
)
# 3. Compute logsumexp over full vocabulary using Triton
student_lse = torch.empty(
(batch_size, teacher_seq_len),
# Allocate output tensor for logprobs directly (always in float32)
student_logprobs_topk = torch.empty(
(batch_size, teacher_seq_len, top_k),
dtype=torch.float32,
device=student_logits.device,
)
# Launch fused kernel directly
grid = (batch_size * teacher_seq_len,)
logsumexp_kernel[grid](
fused_logsumexp_logprobs_kernel[grid](
student_logits_for_kd.contiguous(),
student_lse,
student_logprobs_topk,
target_token_ids.contiguous(),
batch_size,
teacher_seq_len,
vocab_size,
top_k,
kd_temperature,
student_logits_for_kd.stride(0),
student_logits_for_kd.stride(1),
student_logits_for_kd.stride(2),
student_lse.stride(0),
student_lse.stride(1),
student_logprobs_topk.stride(0),
student_logprobs_topk.stride(1),
student_logprobs_topk.stride(2),
target_token_ids.stride(0),
target_token_ids.stride(1),
target_token_ids.stride(2),
min(1024, triton.next_power_of_2(vocab_size)),
)
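# (the trailing positional argument fills BLOCK_SIZE: the vocab dimension is
# swept in power-of-two chunks of at most 1024 lanes per program)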
# 4. Convert to logprobs
student_logprobs_topk = student_logits_topk - student_lse.unsqueeze(-1)
# Calculate probs from logprobs
student_probs_topk = torch.exp(student_logprobs_topk)
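# Reference note (illustrative; the actual loss reduction lives further down in
# this method, outside this hunk). A standard forward-KL reduction over the
# top-k logprobs, with stand-in names for the mask and normalizer, is roughly:
#   teacher_probs = torch.exp(target_logprobs)
#   kl_per_pos = (teacher_probs * (target_logprobs - student_logprobs_topk)).sum(-1)
#   kd_loss = (kl_per_pos * valid_mask).sum() / num_items_in_batch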
# Save tensors for backward pass
@@ -321,7 +436,7 @@ class TopKKLDivergence(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
"""
Optimized backward pass for KL divergence loss.
Optimized backward pass for KL divergence loss with proper dtype handling.
"""
(
student_logits,
@@ -333,11 +448,13 @@ class TopKKLDivergence(torch.autograd.Function):
kd_temperature = ctx.kd_temperature
num_items_in_batch = ctx.num_items_in_batch
# Store original dtype for later conversion
original_dtype = student_logits.dtype
batch_size, seq_len, vocab_size = student_logits.shape
_, _, top_k = target_token_ids.shape
# Initialize gradient tensor
grad_student_logits = torch.zeros_like(student_logits)
# Initialize gradient tensor in float32 to support atomic operations
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
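# (low-precision atomic adds are generally unavailable, so the kernel's
# scatter-add accumulates in float32 and is cast back to the original dtype below)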
# Compute scaling factor
scale = grad_output.item()
@@ -353,15 +470,13 @@ class TopKKLDivergence(torch.autograd.Function):
scale = scale / float(target_mask.sum().item())
# Apply chain rule for temperature scaling (1/temperature)
# This comes from d(logits/temperature)/d(logits) = 1/temperature
if kd_temperature != 1.0:
scale = scale / kd_temperature
# Convert teacher logprobs to probabilities
teacher_probs = torch.exp(target_logprobs)
# Depending on which mode was used in the forward pass, we use a different gradient calculation
# FIXME: the top_k_before_softmax path may not be handled correctly yet
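# For reference, on the softmax-over-top-k path the accumulated gradient is the
# usual softmax/KL result (assuming teacher probs over the kept tokens sum to 1):
#   d(kd_loss)/d(student_logit_k) = scale * (student_prob_k - teacher_prob_k)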
# Launch gradient computation kernel
grid = (batch_size * seq_len,)
grad_softmax_kernel[grid](
grad_student_logits.contiguous(),
@@ -392,6 +507,10 @@ class TopKKLDivergence(torch.autograd.Function):
min(1024, triton.next_power_of_2(top_k)),
)
# Convert gradient back to original dtype if needed
if original_dtype != torch.float32:
grad_student_logits = grad_student_logits.to(original_dtype)
# Return gradients for student_logits and None for other inputs
return grad_student_logits, None, None, None, None, None, None