more optims
This commit is contained in:
@@ -6,7 +6,125 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from .logsumexp import logsumexp_kernel
|
|
||||||
|
@triton.jit
def fused_logsumexp_logprobs_kernel(
    student_logits_ptr,  # Input logits in original dtype
    student_logprobs_ptr,  # Output logprobs (float32)
    token_ids_ptr,  # Token IDs for top-k
    B,
    S,
    V,
    K,  # batch size, seq len, vocab size, top-k
    temperature,
    stride_l_b,
    stride_l_s,
    stride_l_v,
    stride_lp_b,
    stride_lp_s,
    stride_lp_k,
    stride_t_b,
    stride_t_s,
    stride_t_k,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused kernel that computes logsumexp and logprobs for topk tokens.

    Each program handles one (batch, sequence) row, derived from the 1-D
    program id as ``pid // S`` and ``pid % S``. For that row it:

    1. scans the vocabulary in chunks of ``BLOCK_SIZE`` to find the maximum
       temperature-scaled logit,
    2. scans again to accumulate ``sum(exp(logit - max))``,
    3. for each of the ``K`` token ids of the row, gathers the matching
       logit and stores ``logit / temperature - logsumexp``.

    All computations are done in float32 for numerical stability; logits are
    upcast right after being loaded.

    Args:
        student_logits_ptr: Pointer to the logits, indexed with
            ``stride_l_b``, ``stride_l_s``, ``stride_l_v``.
        student_logprobs_ptr: Pointer to the output log-probabilities,
            indexed with ``stride_lp_b``, ``stride_lp_s``, ``stride_lp_k``.
        token_ids_ptr: Pointer to the top-k token ids, indexed with
            ``stride_t_b``, ``stride_t_s``, ``stride_t_k``.
        B, S, V, K: Batch size, sequence length, vocabulary size, top-k.
        temperature: Divisor applied to every logit when it differs from 1.0.
        BLOCK_SIZE: Compile-time number of vocabulary entries per chunk.
    """
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check. `seq_idx` is `pid % S` and therefore always below S, so
    # only the batch index can fall outside the tensor.
    if batch_idx >= B:
        return

    # Promote the row indices to int64 before multiplying by strides: the
    # element offset of a row can exceed the int32 range for large
    # batch * seq * vocab tensors, and a 32-bit product would wrap around.
    row_b = batch_idx.to(tl.int64)
    row_s = seq_idx.to(tl.int64)

    logits_row = student_logits_ptr + row_b * stride_l_b + row_s * stride_l_s
    token_ids_row = token_ids_ptr + row_b * stride_t_b + row_s * stride_t_s
    logprobs_row = student_logprobs_ptr + row_b * stride_lp_b + row_s * stride_lp_s

    # Lane offsets inside one vocabulary chunk (loop invariant).
    lane = tl.arange(0, BLOCK_SIZE)

    # Phase 1: find the max (temperature-scaled) logit across the vocabulary.
    max_val = -float("inf")

    for v_offset in range(0, V, BLOCK_SIZE):
        vocab_idx = v_offset + lane
        mask = vocab_idx < V  # disables the lanes past the end of the row

        ptrs = logits_row + vocab_idx.to(tl.int64) * stride_l_v
        # Padded lanes read as -inf so they can never win the max.
        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)

        # Apply temperature scaling if needed
        if temperature != 1.0:
            block_logits = block_logits / temperature

        block_max = tl.max(block_logits, axis=0)
        max_val = tl.maximum(max_val, block_max)

    # Phase 2: accumulate sum of exp(logits - max_val).
    sum_exp = 0.0

    for v_offset in range(0, V, BLOCK_SIZE):
        vocab_idx = v_offset + lane
        mask = vocab_idx < V

        ptrs = logits_row + vocab_idx.to(tl.int64) * stride_l_v
        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)

        # Apply temperature scaling if needed
        if temperature != 1.0:
            block_logits = block_logits / temperature

        block_exp = tl.exp(block_logits - max_val)
        # Select instead of multiplying by the boolean mask so padded lanes
        # contribute an exact 0.0 regardless of what they evaluated to.
        sum_exp += tl.sum(tl.where(mask, block_exp, 0.0), axis=0)

    # Final logsumexp of the row.
    logsumexp = max_val + tl.log(sum_exp)

    # Phase 3: compute and store logprobs for the top-k tokens.
    for k in range(K):
        # Load token ID for position k
        token_id = tl.load(token_ids_row + k * stride_t_k)

        # NOTE(review): token_id is not range-checked against V here;
        # presumably the caller guarantees 0 <= token_id < V - confirm.
        token_logit_ptr = logits_row + token_id.to(tl.int64) * stride_l_v
        token_logit = tl.load(token_logit_ptr).to(tl.float32)

        # Apply the same temperature scaling used for the logsumexp
        if temperature != 1.0:
            token_logit = token_logit / temperature

        # Compute logprob directly: logit - logsumexp
        token_logprob = token_logit - logsumexp

        tl.store(logprobs_row + k * stride_lp_k, token_logprob)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -218,64 +336,61 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
"""
|
"""
|
||||||
Forward pass for KL divergence loss between top-k logprobs.
|
Forward pass for KL divergence loss between top-k logprobs.
|
||||||
"""
|
"""
|
||||||
# Convert inputs to appropriate types
|
# Only convert target_logprobs to float, leave student_logits as is
|
||||||
student_logits = student_logits.float()
|
|
||||||
target_logprobs = target_logprobs.float()
|
target_logprobs = target_logprobs.float()
|
||||||
|
|
||||||
# Get dimensions
|
# Get dimensions
|
||||||
batch_size, _, vocab_size = student_logits.shape
|
batch_size, _, vocab_size = student_logits.shape
|
||||||
_, teacher_seq_len, _ = target_token_ids.shape
|
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||||
|
|
||||||
# Slice student logits to match teacher sequence length
|
# Slice student logits to match teacher sequence length
|
||||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
|
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
|
||||||
|
|
||||||
if top_k_before_softmax:
|
if top_k_before_softmax:
|
||||||
# 1. Apply temperature to student logits
|
# Apply temperature to student logits
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
||||||
|
|
||||||
# 2. Gather student logits for top-k tokens
|
# Gather student logits for top-k tokens
|
||||||
student_logits_topk = torch.gather(
|
student_logits_topk = torch.gather(
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Compute softmax over gathered logits
|
# Compute softmax over gathered logits
|
||||||
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
|
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
|
||||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||||
else:
|
else:
|
||||||
# 1. Apply temperature to student logits
|
# Allocate output tensor for logprobs directly (always in float32)
|
||||||
if kd_temperature != 1.0:
|
student_logprobs_topk = torch.empty(
|
||||||
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
(batch_size, teacher_seq_len, top_k),
|
||||||
|
|
||||||
# 2. Gather student logits for top-k tokens
|
|
||||||
student_logits_topk = torch.gather(
|
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Compute logsumexp over full vocabulary using Triton
|
|
||||||
student_lse = torch.empty(
|
|
||||||
(batch_size, teacher_seq_len),
|
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=student_logits.device,
|
device=student_logits.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Launch fused kernel directly
|
||||||
grid = (batch_size * teacher_seq_len,)
|
grid = (batch_size * teacher_seq_len,)
|
||||||
logsumexp_kernel[grid](
|
fused_logsumexp_logprobs_kernel[grid](
|
||||||
student_logits_for_kd.contiguous(),
|
student_logits_for_kd.contiguous(),
|
||||||
student_lse,
|
student_logprobs_topk,
|
||||||
|
target_token_ids.contiguous(),
|
||||||
batch_size,
|
batch_size,
|
||||||
teacher_seq_len,
|
teacher_seq_len,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
|
top_k,
|
||||||
|
kd_temperature,
|
||||||
student_logits_for_kd.stride(0),
|
student_logits_for_kd.stride(0),
|
||||||
student_logits_for_kd.stride(1),
|
student_logits_for_kd.stride(1),
|
||||||
student_logits_for_kd.stride(2),
|
student_logits_for_kd.stride(2),
|
||||||
student_lse.stride(0),
|
student_logprobs_topk.stride(0),
|
||||||
student_lse.stride(1),
|
student_logprobs_topk.stride(1),
|
||||||
|
student_logprobs_topk.stride(2),
|
||||||
|
target_token_ids.stride(0),
|
||||||
|
target_token_ids.stride(1),
|
||||||
|
target_token_ids.stride(2),
|
||||||
min(1024, triton.next_power_of_2(vocab_size)),
|
min(1024, triton.next_power_of_2(vocab_size)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Convert to logprobs
|
# Calculate probs from logprobs
|
||||||
student_logprobs_topk = student_logits_topk - student_lse.unsqueeze(-1)
|
|
||||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||||
|
|
||||||
# Save tensors for backward pass
|
# Save tensors for backward pass
|
||||||
@@ -321,7 +436,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
"""
|
"""
|
||||||
Optimized backward pass for KL divergence loss.
|
Optimized backward pass for KL divergence loss with proper dtype handling.
|
||||||
"""
|
"""
|
||||||
(
|
(
|
||||||
student_logits,
|
student_logits,
|
||||||
@@ -333,11 +448,13 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
kd_temperature = ctx.kd_temperature
|
kd_temperature = ctx.kd_temperature
|
||||||
num_items_in_batch = ctx.num_items_in_batch
|
num_items_in_batch = ctx.num_items_in_batch
|
||||||
|
|
||||||
|
# Store original dtype for later conversion
|
||||||
|
original_dtype = student_logits.dtype
|
||||||
batch_size, seq_len, vocab_size = student_logits.shape
|
batch_size, seq_len, vocab_size = student_logits.shape
|
||||||
_, _, top_k = target_token_ids.shape
|
_, _, top_k = target_token_ids.shape
|
||||||
|
|
||||||
# Initialize gradient tensor
|
# Initialize gradient tensor in float32 to support atomic operations
|
||||||
grad_student_logits = torch.zeros_like(student_logits)
|
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
|
||||||
|
|
||||||
# Compute scaling factor
|
# Compute scaling factor
|
||||||
scale = grad_output.item()
|
scale = grad_output.item()
|
||||||
@@ -353,15 +470,13 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
scale = scale / float(target_mask.sum().item())
|
scale = scale / float(target_mask.sum().item())
|
||||||
|
|
||||||
# Apply chain rule for temperature scaling (1/temperature)
|
# Apply chain rule for temperature scaling (1/temperature)
|
||||||
# This comes from d(logits/temperature)/d(logits) = 1/temperature
|
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
scale = scale / kd_temperature
|
scale = scale / kd_temperature
|
||||||
|
|
||||||
# Convert teacher logprobs to probabilities
|
# Convert teacher logprobs to probabilities
|
||||||
teacher_probs = torch.exp(target_logprobs)
|
teacher_probs = torch.exp(target_logprobs)
|
||||||
|
|
||||||
# Depending on which mode was used in forward, we use different gradient calculation
|
# Launch gradient computation kernel
|
||||||
# FIXME: top_k_before_softmax not correctly yet?
|
|
||||||
grid = (batch_size * seq_len,)
|
grid = (batch_size * seq_len,)
|
||||||
grad_softmax_kernel[grid](
|
grad_softmax_kernel[grid](
|
||||||
grad_student_logits.contiguous(),
|
grad_student_logits.contiguous(),
|
||||||
@@ -392,6 +507,10 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
min(1024, triton.next_power_of_2(top_k)),
|
min(1024, triton.next_power_of_2(top_k)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert gradient back to original dtype if needed
|
||||||
|
if original_dtype != torch.float32:
|
||||||
|
grad_student_logits = grad_student_logits.to(original_dtype)
|
||||||
|
|
||||||
# Return gradients for student_logits and None for other inputs
|
# Return gradients for student_logits and None for other inputs
|
||||||
return grad_student_logits, None, None, None, None, None, None
|
return grad_student_logits, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user