more optims
@@ -6,7 +6,125 @@ import torch
 import triton
 import triton.language as tl
 
 from .logsumexp import logsumexp_kernel
 
 
+@triton.jit
+def fused_logsumexp_logprobs_kernel(
+    student_logits_ptr,  # Input logits in original dtype
+    student_logprobs_ptr,  # Output logprobs (float32)
+    token_ids_ptr,  # Token IDs for top-k
+    B,
+    S,
+    V,
+    K,  # batch size, seq len, vocab size, top-k
+    temperature,
+    stride_l_b,
+    stride_l_s,
+    stride_l_v,
+    stride_lp_b,
+    stride_lp_s,
+    stride_lp_k,
+    stride_t_b,
+    stride_t_s,
+    stride_t_k,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Fused kernel that computes logsumexp and logprobs for the top-k tokens.
+    All computations are done in float32 for numerical stability.
+    """
+    # Program ID
+    pid = tl.program_id(0)
+    batch_idx = pid // S
+    seq_idx = pid % S
+
+    # Bounds check
+    if batch_idx >= B or seq_idx >= S:
+        return
+
+    # Compute logsumexp over the vocabulary
+    max_val = -float("inf")
+
+    # Phase 1: Find max value across vocabulary
+    for v_offset in range(0, V, BLOCK_SIZE):
+        # Create block indices and mask
+        block_size = min(BLOCK_SIZE, V - v_offset)
+        block_idx = tl.arange(0, BLOCK_SIZE)
+        mask = block_idx < block_size
+
+        # Load logits block and convert to float32
+        ptrs = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + (v_offset + block_idx) * stride_l_v
+        )
+        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            block_logits = block_logits / temperature
+
+        # Update max value
+        block_max = tl.max(block_logits, axis=0)
+        max_val = tl.maximum(max_val, block_max)
+
+    # Phase 2: Compute sum of exp(logits - max_val)
+    sum_exp = 0.0
+
+    for v_offset in range(0, V, BLOCK_SIZE):
+        # Create block indices and mask
+        block_size = min(BLOCK_SIZE, V - v_offset)
+        block_idx = tl.arange(0, BLOCK_SIZE)
+        mask = block_idx < block_size
+
+        # Load logits block and convert to float32
+        ptrs = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + (v_offset + block_idx) * stride_l_v
+        )
+        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            block_logits = block_logits / temperature
+
+        # Compute exp(logits - max_val) and add to sum
+        block_exp = tl.exp(block_logits - max_val)
+        sum_exp += tl.sum(block_exp * mask, axis=0)
+
+    # Compute final logsumexp
+    logsumexp = max_val + tl.log(sum_exp)
+
+    # Phase 3: Compute and store logprobs for the top-k tokens
+    token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
+    logprobs_base = (
+        student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
+    )
+
+    for k in range(K):
+        # Load token ID for position k
+        token_id = tl.load(token_ids_base + k * stride_t_k)
+
+        # Load the corresponding logit and convert to float32
+        token_logit_ptr = (
+            student_logits_ptr
+            + batch_idx * stride_l_b
+            + seq_idx * stride_l_s
+            + token_id * stride_l_v
+        )
+        token_logit = tl.load(token_logit_ptr).to(tl.float32)
+
+        # Apply temperature scaling if needed
+        if temperature != 1.0:
+            token_logit = token_logit / temperature
+
+        # Compute logprob directly: logit - logsumexp
+        token_logprob = token_logit - logsumexp
+
+        # Store the result
+        tl.store(logprobs_base + k * stride_lp_k, token_logprob)
+
+
 @triton.jit
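
For intuition, a minimal pure-PyTorch sketch of what the fused kernel computes per (batch, seq) position is given below. It is only a slow reference for testing; the tensor names are illustrative and it is not part of the commit.

import torch

def reference_topk_logprobs(student_logits, token_ids, temperature=1.0):
    # student_logits: (B, S, V) in any float dtype; token_ids: (B, S, K) integer indices
    logits = student_logits.float() / temperature              # do the math in float32
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)        # (B, S, 1)
    gathered = torch.gather(logits, dim=-1, index=token_ids)   # (B, S, K)
    return gathered - lse                                      # logprobs of the selected tokens

logits = torch.randn(2, 3, 7, dtype=torch.bfloat16)  # B=2, S=3, V=7
ids = torch.randint(0, 7, (2, 3, 4))                  # K=4
print(reference_topk_logprobs(logits, ids, temperature=2.0).shape)  # torch.Size([2, 3, 4])
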
@@ -218,64 +336,61 @@ class TopKKLDivergence(torch.autograd.Function):
         """
         Forward pass for KL divergence loss between top-k logprobs.
         """
-        # Convert inputs to appropriate types
-        student_logits = student_logits.float()
+        # Only convert target_logprobs to float, leave student_logits as is
         target_logprobs = target_logprobs.float()
 
         # Get dimensions
         batch_size, _, vocab_size = student_logits.shape
-        _, teacher_seq_len, _ = target_token_ids.shape
+        _, teacher_seq_len, top_k = target_token_ids.shape
 
         # Slice student logits to match teacher sequence length
         student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
 
         if top_k_before_softmax:
-            # 1. Apply temperature to student logits
+            # Apply temperature to student logits
             if kd_temperature != 1.0:
                 student_logits_for_kd = student_logits_for_kd / kd_temperature
 
-            # 2. Gather student logits for top-k tokens
+            # Gather student logits for top-k tokens
             student_logits_topk = torch.gather(
                 student_logits_for_kd, dim=-1, index=target_token_ids
             )
 
-            # 3. Compute softmax over gathered logits
+            # Compute softmax over gathered logits
             student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
             student_probs_topk = torch.exp(student_logprobs_topk)
         else:
             # 1. Apply temperature to student logits
             if kd_temperature != 1.0:
                 student_logits_for_kd = student_logits_for_kd / kd_temperature
 
             # 2. Gather student logits for top-k tokens
             student_logits_topk = torch.gather(
                 student_logits_for_kd, dim=-1, index=target_token_ids
             )
 
-            # 3. Compute logsumexp over full vocabulary using Triton
-            student_lse = torch.empty(
-                (batch_size, teacher_seq_len),
+            # Allocate output tensor for logprobs directly (always in float32)
+            student_logprobs_topk = torch.empty(
+                (batch_size, teacher_seq_len, top_k),
                 dtype=torch.float32,
                 device=student_logits.device,
             )
 
+            # Launch fused kernel directly
             grid = (batch_size * teacher_seq_len,)
-            logsumexp_kernel[grid](
+            fused_logsumexp_logprobs_kernel[grid](
                 student_logits_for_kd.contiguous(),
-                student_lse,
+                student_logprobs_topk,
+                target_token_ids.contiguous(),
                 batch_size,
                 teacher_seq_len,
                 vocab_size,
+                top_k,
                 kd_temperature,
                 student_logits_for_kd.stride(0),
                 student_logits_for_kd.stride(1),
                 student_logits_for_kd.stride(2),
-                student_lse.stride(0),
-                student_lse.stride(1),
+                student_logprobs_topk.stride(0),
+                student_logprobs_topk.stride(1),
+                student_logprobs_topk.stride(2),
+                target_token_ids.stride(0),
+                target_token_ids.stride(1),
+                target_token_ids.stride(2),
                 min(1024, triton.next_power_of_2(vocab_size)),
             )
 
-            # 4. Convert to logprobs
-            student_logprobs_topk = student_logits_topk - student_lse.unsqueeze(-1)
+            # Calculate probs from logprobs
             student_probs_topk = torch.exp(student_logprobs_topk)
 
         # Save tensors for backward pass
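
The loss reduction itself lies outside this hunk. As a hedged reference, a typical forward-KL reduction over the gathered top-k logprobs produced above would look roughly like the sketch below (names are illustrative; the repository's actual reduction and masking may differ).

import torch

def topk_forward_kl(target_logprobs, student_logprobs_topk, target_mask=None):
    # target_logprobs, student_logprobs_topk: (B, S, K); target_mask: (B, S) of 0/1
    teacher_probs = torch.exp(target_logprobs)
    kl_per_pos = (teacher_probs * (target_logprobs - student_logprobs_topk)).sum(dim=-1)
    if target_mask is not None:
        return (kl_per_pos * target_mask).sum() / target_mask.sum().clamp_min(1)
    return kl_per_pos.mean()
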
@@ -321,7 +436,7 @@ class TopKKLDivergence(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         """
-        Optimized backward pass for KL divergence loss.
+        Optimized backward pass for KL divergence loss with proper dtype handling.
         """
         (
             student_logits,
@@ -333,11 +448,13 @@ class TopKKLDivergence(torch.autograd.Function):
         kd_temperature = ctx.kd_temperature
         num_items_in_batch = ctx.num_items_in_batch
 
+        # Store original dtype for later conversion
+        original_dtype = student_logits.dtype
         batch_size, seq_len, vocab_size = student_logits.shape
         _, _, top_k = target_token_ids.shape
 
-        # Initialize gradient tensor
-        grad_student_logits = torch.zeros_like(student_logits)
+        # Initialize gradient tensor in float32 to support atomic operations
+        grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
 
         # Compute scaling factor
         scale = grad_output.item()
@@ -353,15 +470,13 @@ class TopKKLDivergence(torch.autograd.Function):
             scale = scale / float(target_mask.sum().item())
 
         # Apply chain rule for temperature scaling (1/temperature)
         # This comes from d(logits/temperature)/d(logits) = 1/temperature
         if kd_temperature != 1.0:
             scale = scale / kd_temperature
 
         # Convert teacher logprobs to probabilities
         teacher_probs = torch.exp(target_logprobs)
 
-        # Depending on which mode was used in forward, we use different gradient calculation
-        # FIXME: top_k_before_softmax not correctly yet?
+        # Launch gradient computation kernel
         grid = (batch_size * seq_len,)
         grad_softmax_kernel[grid](
             grad_student_logits.contiguous(),
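
The 1/temperature factor folded into scale can be sanity-checked with autograd on a toy tensor; this is only an illustrative check, not repository code.

import torch

T = 2.0
x = torch.randn(5, requires_grad=True)
torch.logsumexp(x / T, dim=0).backward()
# d/dx logsumexp(x / T) = softmax(x / T) / T: the gradient picks up a 1/T factor.
print(torch.allclose(x.grad, torch.softmax(x / T, dim=0) / T))  # True
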
@@ -392,6 +507,10 @@ class TopKKLDivergence(torch.autograd.Function):
             min(1024, triton.next_power_of_2(top_k)),
         )
 
+        # Convert gradient back to original dtype if needed
+        if original_dtype != torch.float32:
+            grad_student_logits = grad_student_logits.to(original_dtype)
+
         # Return gradients for student_logits and None for other inputs
         return grad_student_logits, None, None, None, None, None, None
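
One hedged way to validate the float32-then-cast gradient handling is a dense autograd reference that mirrors the fused path (temperature, gather, logsumexp); the shapes and names below are synthetic and the repository's tests may do this differently.

import torch

B, S, V, K, T = 2, 3, 11, 4, 2.0
student_logits = torch.randn(B, S, V, dtype=torch.bfloat16, requires_grad=True)
token_ids = torch.randint(0, V, (B, S, K))
target_logprobs = torch.log_softmax(torch.randn(B, S, K), dim=-1)

# Dense float32 reference mirroring the fused path
logits = student_logits.float() / T
logprobs_topk = torch.gather(logits, -1, token_ids) - torch.logsumexp(logits, -1, keepdim=True)
loss = (torch.exp(target_logprobs) * (target_logprobs - logprobs_topk)).sum() / (B * S)
loss.backward()

# Autograd delivers the gradient in the leaf's dtype (bfloat16 here); the kernel path
# achieves the same by accumulating in float32 and casting back to the original dtype.
print(student_logits.grad.dtype)  # torch.bfloat16
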