"""
|
|
sanity checks on logsumexp kernel validity
|
|
"""
|
|
import torch
|
|
import triton
|
|
|
|
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
|
|
|
|
|
|
# PyTorch implementation of logsumexp for reference
|
|
def torch_logsumexp(logits):
|
|
"""PyTorch implementation of logsumexp over last dimension"""
|
|
return torch.logsumexp(logits, dim=-1)
|
|
|
|
|
|
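

# Illustrative sketch, not used by the tests: a numerically stable logsumexp is
# usually computed as lse(x) = max(x) + log(sum(exp(x - max(x)))), which is what the
# edge-case checks below probe for. The helper name here is ours, not part of the
# kernel's API.
def stable_logsumexp_reference(logits):
    """Max-shifted logsumexp over the last dimension (pure PyTorch, for illustration)"""
    max_val = logits.max(dim=-1, keepdim=True).values
    return max_val.squeeze(-1) + torch.log(torch.exp(logits - max_val).sum(dim=-1))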


# Wrapper function for the Triton logsumexp kernel
def triton_logsumexp(logits):
    """Triton implementation of logsumexp over last dimension"""
    B, S, V = logits.shape  # pylint: disable=invalid-name
    # Make the input contiguous first so the strides passed to the kernel describe
    # the buffer it actually reads
    logits = logits.contiguous()
    output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

    # One program instance per (batch, sequence) position
    grid = (B * S,)
    logsumexp_kernel[grid](
        logits,
        output,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        output.stride(0),
        output.stride(1),
        min(1024, triton.next_power_of_2(V)),
    )

    return output
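

# Example usage (illustrative; assumes a CUDA device and that the imported Triton
# kernel is available):
#   logits = torch.randn(2, 4, 32000, device="cuda")
#   lse = triton_logsumexp(logits)  # float32 tensor of shape (2, 4)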


class TritonLogSumExp(torch.autograd.Function):
    """
    Custom autograd function wrapping the Triton logsumexp kernel, used for gradient testing
    """

    @staticmethod
    def forward(ctx, logits):
        B, S, V = logits.shape  # pylint: disable=invalid-name
        # Keep the contiguous tensor so the strides below match the buffer the kernel
        # reads, and so backward sees the same layout
        logits = logits.contiguous()
        output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

        # Save inputs for the backward pass
        ctx.save_for_backward(logits)
        ctx.shape = logits.shape

        grid = (B * S,)
        logsumexp_kernel[grid](
            logits,
            output,
            B,
            S,
            V,
            logits.stride(0),
            logits.stride(1),
            logits.stride(2),
            output.stride(0),
            output.stride(1),
            min(1024, triton.next_power_of_2(V)),
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors

        # For logsumexp, the gradient is softmax(input) * grad_output.
        # First recompute the logsumexp with the Triton wrapper (no need to re-enter
        # the autograd function here)
        lse = triton_logsumexp(logits)

        # Compute softmax by exponentiating differences
        softmax_output = torch.exp(logits - lse.unsqueeze(-1))

        # Scale the softmax output by the upstream gradient
        grad_input = softmax_output * grad_output.unsqueeze(-1)

        return grad_input
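

# The backward rule above uses the identity d/dx_i logsumexp(x) = softmax(x)_i, so the
# input gradient is softmax(x) scaled by the upstream gradient. A small illustrative
# check of that identity against PyTorch autograd on the reference op (the helper name
# is ours and is not collected by pytest):
def _logsumexp_gradient_identity_demo():
    x = torch.randn(2, 3, 8, device="cuda", requires_grad=True)
    grad_out = torch.randn(2, 3, device="cuda")
    torch.logsumexp(x, dim=-1).backward(grad_out)
    manual = torch.softmax(x.detach(), dim=-1) * grad_out.unsqueeze(-1)
    assert torch.allclose(x.grad, manual, atol=1e-6)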


def test_logsumexp_values():
    """Test that the Triton logsumexp implementation matches PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Test with various input shapes
    test_shapes = [
        (2, 3, 10),  # small vocab
        (4, 5, 100),  # medium vocab
        (2, 2, 32000),  # large vocab (typical for LLMs)
    ]

    for shape in test_shapes:
        # Create random input tensors
        logits = torch.randn(shape, device="cuda", requires_grad=False)

        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Shape {shape}, Max diff: {max_diff}")

        # Assert that the results are very close
        assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"


def test_logsumexp_edge_cases():
    """Test edge cases for numerical stability"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Case 1: Very large values that might cause overflow
    logits_large = torch.ones(2, 3, 100, device="cuda") * 1000

    # Case 2: Very small values that might cause underflow
    logits_small = torch.ones(2, 3, 100, device="cuda") * -1000

    # Case 3: Mix of large and small values
    logits_mixed = torch.zeros(2, 3, 100, device="cuda")
    logits_mixed[:, :, 0] = 1000  # One very large value

    # Case 4: All identical values
    logits_identical = torch.ones(2, 3, 100, device="cuda") * 5

    # Case 5: Extreme values with NaN check
    logits_extreme = torch.cat(
        [
            torch.full((1, 3, 50), 1e10, device="cuda"),
            torch.full((1, 3, 50), -1e10, device="cuda"),
        ],
        dim=0,
    )

    for i, logits in enumerate(
        [logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
    ):
        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Check for NaNs
        assert not torch.isnan(
            torch_result
        ).any(), f"PyTorch produced NaNs for case {i+1}"
        assert not torch.isnan(
            triton_result
        ).any(), f"Triton produced NaNs for case {i+1}"

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Edge case {i+1}, Max diff: {max_diff}")

        # For very extreme values, allow a bit more tolerance
        if i == 4:  # extreme case
            assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
        else:
            assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
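

# Sanity arithmetic for the constant-valued cases above (not asserted by the test):
# when all V entries equal a constant c, logsumexp reduces to c + log(V). For example,
# case 1 gives 1000 + log(100) ≈ 1004.6 and case 4 gives 5 + log(100) ≈ 9.6.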


def test_logsumexp_gradients():
    """Test that the gradients of Triton logsumexp match PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create input tensors with gradients enabled
    shapes = [(2, 3, 10), (4, 5, 100)]

    for shape in shapes:
        # Create two identical tensors for PyTorch and Triton
        logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
        logits_triton = logits_torch.clone().detach().requires_grad_(True)

        # Forward pass
        torch_output = torch_logsumexp(logits_torch)
        triton_output = TritonLogSumExp.apply(logits_triton)

        # Compare forward pass values
        max_diff_forward = (torch_output - triton_output).abs().max().item()
        assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"

        # Create random gradient
        grad_output = torch.randn_like(torch_output)

        # Backward pass
        torch_output.backward(grad_output)
        triton_output.backward(grad_output)

        # Compare gradients
        max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
        print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")

        # Assert that gradients are very close
        assert (
            max_diff_grad < 1e-5
        ), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
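

# Convenience entry point (our assumption about running this file directly; normally
# these checks are collected and run by pytest):
if __name__ == "__main__":
    test_logsumexp_values()
    test_logsumexp_edge_cases()
    test_logsumexp_gradients()
    print("All logsumexp sanity checks passed")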