fix gradients
This commit is contained in:
@@ -6,73 +6,15 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .logsumexp import logsumexp_kernel
|
||||||
# Helper function for computing logsumexp
|
|
||||||
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute logsumexp over the vocab axis for one (batch, seq) position.

    Each program instance maps its program id to a (batch_idx, seq_idx) pair,
    walks the V logits of that row in chunks of BLOCK_SIZE, and stores a single
    scalar at ``output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s``.

    Args:
        logits_ptr: base pointer of the logits; indexed with stride_b/stride_s/stride_v.
        output_ptr: base pointer of the output; indexed with out_stride_b/out_stride_s.
        B, S, V: batch size, sequence length, vocab size.
        stride_b, stride_s, stride_v: strides (in elements) of the logits.
        out_stride_b, out_stride_s: strides (in elements) of the output.
        BLOCK_SIZE: compile-time number of vocab entries processed per iteration.
    """
    # Promote the program id to int64 so that row offsets such as
    # batch_idx * stride_b cannot wrap around for tensors with more than
    # 2**31 elements (large batch * seq * vocab).
    pid = tl.program_id(0).to(tl.int64)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointer to the first vocab entry of this row
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # First pass: find the row maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        offsets = v_offset + tl.arange(0, BLOCK_SIZE)
        # Lanes past the end of the row read -inf so they never win the max
        logits_block = tl.load(
            logits_base + offsets * stride_v,
            mask=offsets < V,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Second pass: accumulate sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        offsets = v_offset + tl.arange(0, BLOCK_SIZE)
        # Out-of-range lanes load -inf and contribute exp(-inf) = 0 to the sum
        logits_block = tl.load(
            logits_base + offsets * stride_v,
            mask=offsets < V,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # logsumexp = max + log(sum(exp(x - max)))
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def grad_softmax_kernel(
|
def grad_softmax_kernel(
|
||||||
grad_student_logits_ptr,
|
grad_student_logits_ptr,
|
||||||
student_logits_ptr,
|
|
||||||
target_token_ids_ptr,
|
target_token_ids_ptr,
|
||||||
teacher_probs_ptr,
|
teacher_probs_ptr,
|
||||||
|
student_probs_ptr,
|
||||||
mask_ptr,
|
mask_ptr,
|
||||||
B,
|
B,
|
||||||
S,
|
S,
|
||||||
@@ -82,15 +24,15 @@ def grad_softmax_kernel(
|
|||||||
stride_gl_b,
|
stride_gl_b,
|
||||||
stride_gl_s,
|
stride_gl_s,
|
||||||
stride_gl_v,
|
stride_gl_v,
|
||||||
stride_l_b,
|
|
||||||
stride_l_s,
|
|
||||||
stride_l_v,
|
|
||||||
stride_t_b,
|
stride_t_b,
|
||||||
stride_t_s,
|
stride_t_s,
|
||||||
stride_t_k,
|
stride_t_k,
|
||||||
stride_p_b,
|
stride_p_b,
|
||||||
stride_p_s,
|
stride_p_s,
|
||||||
stride_p_k,
|
stride_p_k,
|
||||||
|
stride_sp_b,
|
||||||
|
stride_sp_s,
|
||||||
|
stride_sp_k,
|
||||||
stride_m_b,
|
stride_m_b,
|
||||||
stride_m_s,
|
stride_m_s,
|
||||||
stride_m_k,
|
stride_m_k,
|
||||||
@@ -116,21 +58,31 @@ def grad_softmax_kernel(
|
|||||||
teacher_probs_base = (
|
teacher_probs_base = (
|
||||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||||
)
|
)
|
||||||
|
student_probs_base = (
|
||||||
|
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||||
|
)
|
||||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||||
|
|
||||||
# Softmax over full vocab case
|
# Softmax over full vocab case
|
||||||
for k in range(0, K):
|
for k in range(0, K):
|
||||||
# Load token ID, teacher prob, and mask for this position
|
# Load token ID, teacher prob, and mask for this position
|
||||||
token_id = tl.load(token_ids_base + k * stride_t_k)
|
|
||||||
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
||||||
|
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
|
||||||
mask_val = tl.load(mask_base + k * stride_m_k)
|
mask_val = tl.load(mask_base + k * stride_m_k)
|
||||||
|
for j in range(0, K):
|
||||||
# Apply mask by scaling gradient to zero if masked
|
other_token_id = tl.load(token_ids_base + j * stride_t_k)
|
||||||
grad_val = teacher_prob * scale * mask_val
|
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
|
||||||
|
mask_j = tl.load(mask_base + j * stride_m_k)
|
||||||
# Update the gradient for this token's position in the vocabulary
|
combined_mask = mask_val * mask_j
|
||||||
# Only contributes if mask_val is non-zero
|
is_diagonal = tl.where(j == k, 1.0, 0.0)
|
||||||
tl.atomic_add(grad_logits_base + token_id * stride_gl_v, grad_val)
|
self_grad = teacher_prob * (1.0 - student_prob_k)
|
||||||
|
cross_grad = -teacher_prob * student_prob_j
|
||||||
|
grad_val = (
|
||||||
|
-(self_grad * is_diagonal + cross_grad * (1.0 - is_diagonal))
|
||||||
|
* scale
|
||||||
|
* combined_mask
|
||||||
|
)
|
||||||
|
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -338,7 +290,6 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
)
|
)
|
||||||
kd_loss = token_losses.sum()
|
kd_loss = token_losses.sum()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
# Apply temperature scaling
|
# Apply temperature scaling
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
kd_loss = kd_loss * (kd_temperature**2)
|
||||||
@@ -376,7 +327,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
# Compute scaling factor
|
# Compute scaling factor
|
||||||
scale = grad_output.item()
|
scale = grad_output.item()
|
||||||
|
|
||||||
# Apply temperature scaling
|
# Apply temperature scaling from forward pass
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
scale = scale * (kd_temperature**2)
|
scale = scale * (kd_temperature**2)
|
||||||
|
|
||||||
@@ -386,7 +337,8 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
else:
|
else:
|
||||||
scale = scale / float(target_mask.sum().item())
|
scale = scale / float(target_mask.sum().item())
|
||||||
|
|
||||||
# If we used temperature scaling in the forward pass, we need to apply it in the backward pass
|
# Apply chain rule for temperature scaling (1/temperature)
|
||||||
|
# This comes from d(logits/temperature)/d(logits) = 1/temperature
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
scale = scale / kd_temperature
|
scale = scale / kd_temperature
|
||||||
|
|
||||||
@@ -434,9 +386,9 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
grid = (batch_size * seq_len,)
|
grid = (batch_size * seq_len,)
|
||||||
grad_softmax_kernel[grid](
|
grad_softmax_kernel[grid](
|
||||||
grad_student_logits.contiguous(),
|
grad_student_logits.contiguous(),
|
||||||
student_logits.contiguous(),
|
|
||||||
target_token_ids.contiguous(),
|
target_token_ids.contiguous(),
|
||||||
teacher_probs.contiguous(),
|
teacher_probs.contiguous(),
|
||||||
|
student_probs.contiguous(),
|
||||||
target_mask.contiguous(),
|
target_mask.contiguous(),
|
||||||
batch_size,
|
batch_size,
|
||||||
seq_len,
|
seq_len,
|
||||||
@@ -446,15 +398,15 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
grad_student_logits.stride(0),
|
grad_student_logits.stride(0),
|
||||||
grad_student_logits.stride(1),
|
grad_student_logits.stride(1),
|
||||||
grad_student_logits.stride(2),
|
grad_student_logits.stride(2),
|
||||||
student_logits.stride(0),
|
|
||||||
student_logits.stride(1),
|
|
||||||
student_logits.stride(2),
|
|
||||||
target_token_ids.stride(0),
|
target_token_ids.stride(0),
|
||||||
target_token_ids.stride(1),
|
target_token_ids.stride(1),
|
||||||
target_token_ids.stride(2),
|
target_token_ids.stride(2),
|
||||||
teacher_probs.stride(0),
|
teacher_probs.stride(0),
|
||||||
teacher_probs.stride(1),
|
teacher_probs.stride(1),
|
||||||
teacher_probs.stride(2),
|
teacher_probs.stride(2),
|
||||||
|
student_probs.stride(0),
|
||||||
|
student_probs.stride(1),
|
||||||
|
student_probs.stride(2),
|
||||||
target_mask.stride(0),
|
target_mask.stride(0),
|
||||||
target_mask.stride(1),
|
target_mask.stride(1),
|
||||||
target_mask.stride(2),
|
target_mask.stride(2),
|
||||||
|
|||||||
67
src/axolotl/integrations/kd/topk_logprob/logsumexp.py
Normal file
67
src/axolotl/integrations/kd/topk_logprob/logsumexp.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
Optimized Triton kernels for logsumexp
|
||||||
|
"""
|
||||||
|
# pylint: disable=invalid-name,unused-argument
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function for computing logsumexp
|
||||||
|
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute a numerically stable logsumexp over the vocab axis.

    One program instance handles one (batch_idx, seq_idx) row: the program id
    is split as ``pid // S`` and ``pid % S``. The row is scanned twice in
    chunks of BLOCK_SIZE (max, then sum of shifted exponentials) and a single
    value is written to
    ``output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s``.

    Args:
        logits_ptr: base pointer of the logits; indexed with stride_b/stride_s/stride_v.
        output_ptr: base pointer of the output; indexed with out_stride_b/out_stride_s.
        B, S, V: batch size, sequence length, vocab size.
        stride_b, stride_s, stride_v: strides (in elements) of the logits.
        out_stride_b, out_stride_s: strides (in elements) of the output.
        BLOCK_SIZE: compile-time number of vocab entries processed per iteration.
    """
    # pylint: disable=duplicate-code
    # Use a 64-bit program id: with 32-bit arithmetic the row offset
    # batch_idx * stride_b + seq_idx * stride_s overflows once the logits
    # tensor holds more than 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointers
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # Find maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        vocab_idx = v_offset + tl.arange(0, BLOCK_SIZE)
        in_bounds = vocab_idx < V
        # Masked-off lanes are filled with -inf so they cannot affect the max
        logits_block = tl.load(
            logits_base + vocab_idx * stride_v,
            mask=in_bounds,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Compute sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        vocab_idx = v_offset + tl.arange(0, BLOCK_SIZE)
        in_bounds = vocab_idx < V
        # Masked-off lanes load -inf, i.e. add exp(-inf) = 0 to the sum
        logits_block = tl.load(
            logits_base + vocab_idx * stride_v,
            mask=in_bounds,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # Compute logsumexp
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
|
||||||
162
tests/e2e/integrations/test_kl_loss.py
Normal file
162
tests/e2e/integrations/test_kl_loss.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
sanity checks on kl loss and gradients
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Import both implementations
|
||||||
|
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||||
|
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||||
|
|
||||||
|
|
||||||
|
def test_kl_loss_gradient():
    """Check that the Triton KL loss and its gradient agree with the eager version.

    Runs both implementations on identical random inputs, first with
    top_k_before_softmax=0 and then with top_k_before_softmax=1.
    """

    def run_backward(loss_fn, logits, mode):
        # Evaluate one implementation and return its loss plus a copy of d(loss)/d(logits)
        loss = loss_fn(
            logits,
            target_token_ids,
            target_logprobs,
            target_mask,
            num_items_in_batch,
            kd_temperature,
            mode,
        )
        loss.backward()
        return loss, logits.grad.clone()

    def max_relative_diff(reference, candidate):
        # Largest relative error over entries where the reference gradient is
        # non-zero; None when there is no such entry.
        nonzero = reference.abs() > 1e-10
        if not nonzero.sum() > 0:
            return None
        abs_err = (reference[nonzero] - candidate[nonzero]).abs()
        return (abs_err / (reference[nonzero].abs() + 1e-10)).max().item()

    # Fixed seed so the run is reproducible
    torch.manual_seed(42)

    batch_size, seq_len, vocab_size, top_k = 2, 3, 100, 5
    topk_shape = (batch_size, seq_len, top_k)

    # Student logits, plus an independent leaf copy for the Triton path
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Valid vocabulary indices for the teacher's top-k tokens
    target_token_ids = torch.randint(0, vocab_size, topk_shape, device="cuda")

    # Teacher log-probabilities, normalized over the top-k axis
    target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
    target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)

    # Random 0/1 mask so that some entries are masked out
    target_mask = torch.randint(0, 2, topk_shape, device="cuda").float()

    num_items_in_batch = batch_size * seq_len
    kd_temperature = 1.0

    # ---- Mode 1: top_k_before_softmax = 0 ----
    loss_eager, grad_eager = run_backward(eager_loss, student_logits, 0)

    # Reset gradients
    student_logits.grad.zero_()

    loss_triton, grad_triton = run_backward(triton_loss, student_logits_triton, 0)

    # Compare loss values
    print(f"Eager loss: {loss_eager.item()}")
    print(f"Triton loss: {loss_triton.item()}")
    loss_diff = abs(loss_eager.item() - loss_triton.item())
    print(f"Loss difference: {loss_diff}")
    assert loss_diff < 1e-5, "Loss values differ significantly!"

    # Compare gradients
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print(f"Max gradient difference: {grad_diff}")

    # Show one sample entry from each gradient
    sample_idx = (0, 0, 0)  # (batch, seq, vocab)
    print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
    print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")

    rel_diff = max_relative_diff(grad_eager, grad_triton)
    if rel_diff is not None:
        print(f"Max relative gradient difference: {rel_diff}")
        assert rel_diff < 1e-3, "Gradients differ significantly!"

    # ---- Mode 2: top_k_before_softmax = 1, with fresh student logits ----
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    loss_eager, grad_eager = run_backward(eager_loss, student_logits, 1)
    loss_triton, grad_triton = run_backward(triton_loss, student_logits_triton, 1)

    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print("\nWith top_k_before_softmax=1:")
    print(f"Max gradient difference: {grad_diff}")

    rel_diff = max_relative_diff(grad_eager, grad_triton)
    if rel_diff is not None:
        assert (
            rel_diff < 1e-3
        ), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
|
||||||
204
tests/e2e/integrations/test_logsumexp.py
Normal file
204
tests/e2e/integrations/test_logsumexp.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
sanity checks on logsumexp kernel validity
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
|
||||||
|
|
||||||
|
|
||||||
|
# PyTorch implementation of logsumexp for reference
|
||||||
|
def torch_logsumexp(logits):
    """Reference logsumexp: reduce the last dimension of ``logits`` with PyTorch."""
    reduced = logits.logsumexp(dim=-1)
    return reduced
|
||||||
|
|
||||||
|
|
||||||
|
# Wrapper function for Triton logsumexp kernel
|
||||||
|
def triton_logsumexp(logits):
    """Triton implementation of logsumexp over last dimension

    Launches ``logsumexp_kernel`` with one program per (batch, seq) row and
    returns a float32 tensor of shape (B, S) on the same device as ``logits``.

    Args:
        logits: 3-D tensor unpacked as (B, S, V).
    """
    B, S, V = logits.shape  # pylint: disable=invalid-name

    # Make the tensor contiguous *before* reading its strides: passing
    # logits.contiguous() together with the strides of the original
    # (possibly non-contiguous) tensor would make the kernel index the
    # copied buffer with the wrong layout.
    logits = logits.contiguous()
    output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

    grid = (B * S,)
    logsumexp_kernel[grid](
        logits,
        output,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        output.stride(0),
        output.stride(1),
        # Block size: next power of two of the vocab size, capped at 1024
        min(1024, triton.next_power_of_2(V)),
    )

    return output
|
||||||
|
|
||||||
|
|
||||||
|
class TritonLogSumExp(torch.autograd.Function):
    """
    Wrap a custom autograd function to use the Triton logsumexp for gradient testing

    forward launches ``logsumexp_kernel``; backward uses the analytic gradient
    softmax(logits) * grad_output, built from the saved forward result.
    """

    @staticmethod
    def forward(ctx, logits):
        """Return the float32 (B, S) logsumexp of a (B, S, V) ``logits`` tensor."""
        B, S, V = logits.shape  # pylint: disable=invalid-name

        # The kernel must be given the strides of the buffer it actually reads,
        # so take them from the contiguous copy rather than from the original
        # (possibly non-contiguous) input.
        kernel_input = logits.contiguous()
        output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

        grid = (B * S,)
        logsumexp_kernel[grid](
            kernel_input,
            output,
            B,
            S,
            V,
            kernel_input.stride(0),
            kernel_input.stride(1),
            kernel_input.stride(2),
            output.stride(0),
            output.stride(1),
            min(1024, triton.next_power_of_2(V)),
        )

        # Keep the input and the forward result so backward does not have to
        # launch the kernel a second time.
        ctx.save_for_backward(logits, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return d(loss)/d(logits) = softmax(logits) * grad_output."""
        logits, lse = ctx.saved_tensors

        # softmax(x) = exp(x - logsumexp(x)), broadcast over the vocab axis
        softmax_output = torch.exp(logits - lse.unsqueeze(-1))

        # Chain rule: scale each row's softmax by the incoming gradient
        grad_input = softmax_output * grad_output.unsqueeze(-1)

        # lse is float32, so cast back in case the input used another dtype
        return grad_input.to(logits.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def test_logsumexp_values():
    """Triton logsumexp must agree with torch.logsumexp on random inputs."""
    # Fixed seed so the run is reproducible
    torch.manual_seed(42)

    # Small, medium and LLM-sized vocabularies
    for shape in ((2, 3, 10), (4, 5, 100), (2, 2, 32000)):
        logits = torch.randn(shape, device="cuda")

        # Evaluate the reference and the kernel on the same tensor
        expected = torch_logsumexp(logits)
        actual = triton_logsumexp(logits)

        max_diff = torch.sub(expected, actual).abs().max().item()
        print(f"Shape {shape}, Max diff: {max_diff}")

        # Results should be very close
        assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_logsumexp_edge_cases():
    """Compare Triton and PyTorch logsumexp on numerically awkward inputs."""
    # Fixed seed so the run is reproducible
    torch.manual_seed(42)

    def constant(value):
        # (2, 3, 100) tensor filled with a single value
        return torch.ones(2, 3, 100, device="cuda") * value

    # Case 3 input: all zeros except one very large entry per row
    logits_mixed = torch.zeros(2, 3, 100, device="cuda")
    logits_mixed[:, :, 0] = 1000

    # Case 5 input: one batch entry of +1e10 and one of -1e10
    logits_extreme = torch.cat(
        [
            torch.full((1, 3, 50), 1e10, device="cuda"),
            torch.full((1, 3, 50), -1e10, device="cuda"),
        ],
        dim=0,
    )

    cases = [
        constant(1000),  # Case 1: very large values (overflow risk)
        constant(-1000),  # Case 2: very small values (underflow risk)
        logits_mixed,  # Case 3: mix of large and small values
        constant(5),  # Case 4: all identical values
        logits_extreme,  # Case 5: extreme magnitudes
    ]

    for i, logits in enumerate(cases):
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Neither implementation may produce NaNs
        assert not torch.isnan(torch_result).any(), f"PyTorch produced NaNs for case {i+1}"
        assert not torch.isnan(triton_result).any(), f"Triton produced NaNs for case {i+1}"

        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Edge case {i+1}, Max diff: {max_diff}")

        # The extreme case (index 4) is allowed a looser tolerance
        tolerance = 1e-2 if i == 4 else 1e-5
        assert max_diff < tolerance, f"Results differ too much for edge case {i+1}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_logsumexp_gradients():
    """Gradients through TritonLogSumExp must agree with PyTorch autograd."""
    # Fixed seed so the run is reproducible
    torch.manual_seed(42)

    for shape in ((2, 3, 10), (4, 5, 100)):
        # Two independent leaf tensors holding the same values
        logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
        logits_triton = logits_torch.detach().clone().requires_grad_(True)

        # Forward pass through both implementations
        torch_output = torch_logsumexp(logits_torch)
        triton_output = TritonLogSumExp.apply(logits_triton)

        max_diff_forward = torch.sub(torch_output, triton_output).abs().max().item()
        assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"

        # Push the same random upstream gradient through both graphs
        grad_output = torch.randn_like(torch_output)
        torch_output.backward(grad_output)
        triton_output.backward(grad_output)

        max_diff_grad = torch.sub(logits_torch.grad, logits_triton.grad).abs().max().item()
        print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")

        # Gradients should be very close
        assert (
            max_diff_grad < 1e-5
        ), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
|
||||||
Reference in New Issue
Block a user