chunk to prevent overflows in kernel

Wing Lian
2025-02-26 04:44:24 -05:00
parent 23f029a89c
commit 68e97d032a


@@ -2,6 +2,8 @@
Optimized Triton kernel for KL divergence loss between teacher and student models.
"""
# pylint: disable=invalid-name,unused-argument
from typing import List, Optional, Tuple, Union
import torch
import triton
import triton.language as tl
@@ -316,12 +318,44 @@ def grad_topk_softmax_kernel(
        tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)


# Chunking helper functions for handling long sequences
def chunk_tensor(
    tensor: torch.Tensor, max_seq_len: int
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], Optional[int]]:
    """Split a tensor along the sequence dimension if it exceeds max_seq_len."""
    _, seq_len, *__ = tensor.shape
    if seq_len <= max_seq_len:
        return tensor, None
    num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
    chunks = []
    for i in range(num_chunks):
        start_idx = i * max_seq_len
        end_idx = min((i + 1) * max_seq_len, seq_len)
        chunks.append(tensor[:, start_idx:end_idx, ...])
    return chunks, num_chunks


def merge_chunks(chunks: List[torch.Tensor], original_shape: torch.Size):
    """Merge chunks back into a single tensor along the sequence dimension."""
    return torch.cat(chunks, dim=1)
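A quick sanity sketch of how these helpers compose. The shapes and lengths here are made up for illustration, and note that chunk_tensor returns the tensor unchanged (with None) when no split is needed:

# Illustrative round trip: split a [B, S, K] tensor along dim 1, then re-merge.
x = torch.randn(2, 20000, 64)                # made-up sizes: B=2, S=20000, K=64
chunks, num_chunks = chunk_tensor(x, 8192)   # ceil(20000 / 8192) = 3 chunks
merged = merge_chunks(chunks, x.shape)       # original_shape is currently unused
assert num_chunks == 3 and torch.equal(merged, x)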
# Triton-accelerated implementation of KL divergence loss for top-k tokens
class TopKKLDivergence(torch.autograd.Function):
    """
    Autograd function for KL divergence loss between top-k logprobs,
    with support for chunking to handle very long sequences.
    """

    # Max sequence length to process in a single kernel launch.
    # This is a tunable parameter that may need adjustment based on GPU memory.
    MAX_SEQ_LEN = 8192

    @staticmethod
    def forward(
        ctx,
@@ -334,7 +368,7 @@ class TopKKLDivergence(torch.autograd.Function):
        top_k_before_softmax=0,
    ):
        """
        Forward pass for KL divergence loss between top-k logprobs, with chunking.
        """
        # Only convert target_logprobs to float; leave student_logits as is
        target_logprobs = target_logprobs.float()
@@ -346,52 +380,145 @@ class TopKKLDivergence(torch.autograd.Function):
        # Slice student logits to match the teacher sequence length
        student_logits_for_kd = student_logits[:, :teacher_seq_len, :]

        # Store original values for the backward pass
        ctx.original_seq_len = teacher_seq_len
        ctx.original_dtype = student_logits.dtype

        # Apply chunking for long sequences
        if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
            # Chunk the inputs
            student_logits_chunks, num_chunks = chunk_tensor(
                student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
            )
            target_token_ids_chunks, _ = chunk_tensor(
                target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
            )
            # target_logprobs_chunks, _ = chunk_tensor(
            #     target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
            # )
            # target_mask_chunks, _ = chunk_tensor(
            #     target_mask, TopKKLDivergence.MAX_SEQ_LEN
            # )

            # Process each chunk
            student_logprobs_chunks = []
            student_probs_chunks = []
            for i in range(num_chunks):
                chunk_logits = student_logits_chunks[i]
                chunk_token_ids = target_token_ids_chunks[i]
                chunk_seq_len = chunk_logits.shape[1]

                if top_k_before_softmax:
                    # Apply temperature to student logits
                    if kd_temperature != 1.0:
                        chunk_logits = chunk_logits / kd_temperature
                    # Gather student logits for the top-k tokens
                    chunk_logits_topk = torch.gather(
                        chunk_logits, dim=-1, index=chunk_token_ids
                    )
                    # Compute softmax over the gathered logits
                    chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
                    chunk_probs_topk = torch.exp(chunk_logprobs_topk)
                else:
                    # Allocate the output tensor for logprobs directly (always float32)
                    chunk_logprobs_topk = torch.empty(
                        (batch_size, chunk_seq_len, top_k),
                        dtype=torch.float32,
                        device=chunk_logits.device,
                    )
                    # Launch the fused kernel on the chunk views directly: their
                    # strides are passed explicitly, so no .contiguous() copy is
                    # needed (a copy's layout would not match the strides below)
                    grid = (batch_size * chunk_seq_len,)
                    fused_logsumexp_logprobs_kernel[grid](
                        chunk_logits,
                        chunk_logprobs_topk,
                        chunk_token_ids,
                        batch_size,
                        chunk_seq_len,
                        vocab_size,
                        top_k,
                        kd_temperature,
                        chunk_logits.stride(0),
                        chunk_logits.stride(1),
                        chunk_logits.stride(2),
                        chunk_logprobs_topk.stride(0),
                        chunk_logprobs_topk.stride(1),
                        chunk_logprobs_topk.stride(2),
                        chunk_token_ids.stride(0),
                        chunk_token_ids.stride(1),
                        chunk_token_ids.stride(2),
                        min(1024, triton.next_power_of_2(vocab_size)),
                    )
                    # Calculate probs from logprobs
                    chunk_probs_topk = torch.exp(chunk_logprobs_topk)

                # Store the chunk results
                student_logprobs_chunks.append(chunk_logprobs_topk)
                student_probs_chunks.append(chunk_probs_topk)

            # Merge the chunk results
            student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
            student_probs_topk = torch.cat(student_probs_chunks, dim=1)

            # Save chunking info for the backward pass
            ctx.used_chunking = True
            ctx.num_chunks = num_chunks
        else:
            # Original code path for shorter sequences
            if top_k_before_softmax:
                # Apply temperature to student logits
                if kd_temperature != 1.0:
                    student_logits_for_kd = student_logits_for_kd / kd_temperature
                # Gather student logits for the top-k tokens
                student_logits_topk = torch.gather(
                    student_logits_for_kd, dim=-1, index=target_token_ids
                )
                # Compute softmax over the gathered logits
                student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
                student_probs_topk = torch.exp(student_logprobs_topk)
            else:
                # Allocate the output tensor for logprobs directly (always float32)
                student_logprobs_topk = torch.empty(
                    (batch_size, teacher_seq_len, top_k),
                    dtype=torch.float32,
                    device=student_logits.device,
                )
                # Launch the fused kernel directly
                grid = (batch_size * teacher_seq_len,)
                fused_logsumexp_logprobs_kernel[grid](
                    student_logits_for_kd.contiguous(),
                    student_logprobs_topk,
                    target_token_ids.contiguous(),
                    batch_size,
                    teacher_seq_len,
                    vocab_size,
                    top_k,
                    kd_temperature,
                    student_logits_for_kd.stride(0),
                    student_logits_for_kd.stride(1),
                    student_logits_for_kd.stride(2),
                    student_logprobs_topk.stride(0),
                    student_logprobs_topk.stride(1),
                    student_logprobs_topk.stride(2),
                    target_token_ids.stride(0),
                    target_token_ids.stride(1),
                    target_token_ids.stride(2),
                    min(1024, triton.next_power_of_2(vocab_size)),
                )
                # Calculate probs from logprobs
                student_probs_topk = torch.exp(student_logprobs_topk)

            # No chunking used
            ctx.used_chunking = False

        # Save tensors for the backward pass
        ctx.save_for_backward(
@@ -408,9 +535,20 @@ class TopKKLDivergence(torch.autograd.Function):
        # Convert the mask to boolean
        valid_mask = target_mask.bool()

        # Extract valid tokens only; this is where the illegal memory access
        # occurred. Use flattened views and explicit indexing for safety.
        student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
        target_logprobs_flat = target_logprobs.view(-1, top_k)
        valid_mask_flat = valid_mask.view(-1, top_k)

        # Gather valid indices explicitly instead of boolean advanced indexing
        valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
        student_logprobs_valid = torch.index_select(
            student_logprobs_flat.view(-1), 0, valid_indices
        )
        target_logprobs_valid = torch.index_select(
            target_logprobs_flat.view(-1), 0, valid_indices
        )

        # Convert teacher logprobs to probabilities
        teacher_probs_valid = torch.exp(target_logprobs_valid)
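As a sanity check on the workaround above, a toy sketch (made-up shapes, CPU-only) showing that boolean advanced indexing and the flattened index_select path pick the same elements in the same order:

lp = torch.arange(12.0).view(1, 2, 6)             # [B=1, S=2, K=6], toy values
mask = torch.tensor([[[1] * 6, [0] * 6]]).bool()  # first token valid, second masked
a = lp[mask]                                      # original boolean formulation
idx = torch.nonzero(mask.view(-1)).squeeze(-1)
b = torch.index_select(lp.view(-1), 0, idx)       # explicit-index workaround
assert torch.equal(a, b)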
@@ -430,14 +568,15 @@ class TopKKLDivergence(torch.autograd.Function):
        if num_items_in_batch > 0:
            kd_loss = kd_loss / float(num_items_in_batch)
        else:
            num_valid_tokens = valid_indices.numel()
            kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)

        return kd_loss
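The fallback normalizer above replaces the old token_losses.size(0) denominator; a made-up-numbers sketch of the intended behavior, including the empty-mask guard:

# Illustrative only: average the summed loss over valid entries, or divide by 1
# when the mask selects nothing (avoids a division by zero).
kd_loss_sum = torch.tensor(10.0)
for num_valid in (5, 0):
    print(kd_loss_sum / float(num_valid if num_valid > 0 else 1))  # 2.0, then 10.0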
    @staticmethod
    def backward(ctx, grad_output):
        """
        Optimized backward pass for KL divergence loss, with proper dtype handling and chunking.
        """
        (
            student_logits,
@@ -448,11 +587,11 @@ class TopKKLDivergence(torch.autograd.Function):
        ) = ctx.saved_tensors
        kd_temperature = ctx.kd_temperature
        num_items_in_batch = ctx.num_items_in_batch
        original_dtype = ctx.original_dtype

        # Get dimensions
        batch_size, _, vocab_size = student_logits.shape
        _, teacher_seq_len, top_k = target_token_ids.shape

        # Initialize the gradient tensor in float32 to support atomic operations
        grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
@@ -477,36 +616,89 @@ class TopKKLDivergence(torch.autograd.Function):
        # Convert teacher logprobs to probabilities
        teacher_probs = torch.exp(target_logprobs)
        # Use chunking for the backward pass if it was used in the forward pass
        if getattr(ctx, "used_chunking", False):
            num_chunks = ctx.num_chunks
            max_seq = TopKKLDivergence.MAX_SEQ_LEN

            # Process each chunk
            for i in range(num_chunks):
                start_idx = i * max_seq
                end_idx = min((i + 1) * max_seq, teacher_seq_len)
                chunk_len = end_idx - start_idx

                # Get chunk slices. These are views, not copies: the kernel is
                # given their true strides, so its atomic adds update
                # grad_student_logits in place. Calling .contiguous() on the
                # slices would silently copy non-contiguous views, losing the
                # gradient writes and mismatching the strides passed below.
                # student_logits_chunk = student_logits[:, start_idx:end_idx, :]
                target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
                teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
                student_probs_chunk = student_probs[:, start_idx:end_idx, :]
                target_mask_chunk = target_mask[:, start_idx:end_idx, :]
                grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]

                # Launch the gradient computation kernel for this chunk
                grid = (batch_size * chunk_len,)
                grad_softmax_kernel[grid](
                    grad_student_logits_chunk,
                    target_token_ids_chunk,
                    teacher_probs_chunk,
                    student_probs_chunk,
                    target_mask_chunk,
                    batch_size,
                    chunk_len,
                    vocab_size,
                    top_k,
                    scale,
                    grad_student_logits_chunk.stride(0),
                    grad_student_logits_chunk.stride(1),
                    grad_student_logits_chunk.stride(2),
                    target_token_ids_chunk.stride(0),
                    target_token_ids_chunk.stride(1),
                    target_token_ids_chunk.stride(2),
                    teacher_probs_chunk.stride(0),
                    teacher_probs_chunk.stride(1),
                    teacher_probs_chunk.stride(2),
                    student_probs_chunk.stride(0),
                    student_probs_chunk.stride(1),
                    student_probs_chunk.stride(2),
                    target_mask_chunk.stride(0),
                    target_mask_chunk.stride(1),
                    target_mask_chunk.stride(2),
                    min(1024, triton.next_power_of_2(top_k)),
                )
                # grad_student_logits is updated in place through the views above
        else:
            # Original code path for shorter sequences
            # Launch the gradient computation kernel
            grid = (batch_size * teacher_seq_len,)
            grad_softmax_kernel[grid](
                grad_student_logits.contiguous(),
                target_token_ids.contiguous(),
                teacher_probs.contiguous(),
                student_probs.contiguous(),
                target_mask.contiguous(),
                batch_size,
                teacher_seq_len,
                vocab_size,
                top_k,
                scale,
                grad_student_logits.stride(0),
                grad_student_logits.stride(1),
                grad_student_logits.stride(2),
                target_token_ids.stride(0),
                target_token_ids.stride(1),
                target_token_ids.stride(2),
                teacher_probs.stride(0),
                teacher_probs.stride(1),
                teacher_probs.stride(2),
                student_probs.stride(0),
                student_probs.stride(1),
                student_probs.stride(2),
                target_mask.stride(0),
                target_mask.stride(1),
                target_mask.stride(2),
                min(1024, triton.next_power_of_2(top_k)),
            )

        # Convert the gradient back to the original dtype if needed
        if original_dtype != torch.float32:
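The chunked backward above relies on the gradient slices being views; a small CPU sketch (toy shapes, no Triton required) of why .contiguous() on those slices would break in-place accumulation:

g = torch.zeros(2, 6, 3)
view = g[:, 2:4, :]                   # non-contiguous view into g when B > 1
view += 1.0                           # in-place writes land in g itself
assert g[:, 2:4, :].sum() == view.numel()
copy = g[:, 4:6, :].contiguous()      # non-contiguous slice: .contiguous() copies
copy += 1.0                           # writes land only in the copy
assert g[:, 4:6, :].sum() == 0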
@@ -525,9 +717,11 @@ def loss(
    num_items_in_batch: int = -1,
    kd_temperature: float = 1.0,
    top_k_before_softmax: int = 0,
    max_seq_len: Optional[int] = None,
):
    """
    Triton-accelerated, memory-efficient KL divergence loss for knowledge
    distillation, with support for very long sequences.

    Args:
        student_logits: Student logits [B, seq_len, vocab_size]
@@ -537,7 +731,12 @@ def loss(
        num_items_in_batch: Number of items for normalization (-1 for auto)
        kd_temperature: Temperature for KD
        top_k_before_softmax: Flag for softmax application order
        max_seq_len: Override the default MAX_SEQ_LEN value used for chunking
    """
    # Allow overriding the max sequence length used for chunking
    if max_seq_len is not None and max_seq_len > 0:
        TopKKLDivergence.MAX_SEQ_LEN = max_seq_len

    total_loss = TopKKLDivergence.apply(
        student_logits,
        target_token_ids,
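To close, a hedged usage sketch of the loss entry point. The keyword names are taken from the docstring and forward signature above, while the shapes, dtypes, and the max_seq_len override are made up; the teacher tensors here are random stand-ins, not real distillation targets:

B, S, V, K = 2, 20000, 32000, 64      # made-up sizes; S > MAX_SEQ_LEN forces chunking
student_logits = torch.randn(B, S, V, device="cuda", dtype=torch.bfloat16)
target_token_ids = torch.randint(0, V, (B, S, K), device="cuda")
target_logprobs = torch.log_softmax(torch.randn(B, S, K, device="cuda"), dim=-1)
target_mask = torch.ones(B, S, K, device="cuda", dtype=torch.int32)
kd_loss = loss(
    student_logits=student_logits,
    target_token_ids=target_token_ids,
    target_logprobs=target_logprobs,
    target_mask=target_mask,
    kd_temperature=1.0,
    max_seq_len=4096,                 # optional override of TopKKLDivergence.MAX_SEQ_LEN
)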