"""
|
|
sanity checks on kl loss and gradients
|
|
"""
|
|
import torch
|
|
|
|
# Import both implementations
|
|
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
|
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
|
|
|
|
|
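# The checks below allocate every tensor on "cuda". This skip guard is an optional
# addition (it assumes the module is collected by pytest) so the checks are skipped
# cleanly on CPU-only machines instead of failing with a device error.
import pytest

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="KL loss sanity checks require a CUDA device"
)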
def test_kl_loss_gradient():
    """Test that the gradient of the Triton implementation matches the eager implementation."""

    # Set the random seed for reproducibility
    torch.manual_seed(42)

    # Create random inputs
    batch_size = 2
    seq_len = 3
    vocab_size = 100
    top_k = 5

    # Generate random student logits
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Generate random target token IDs, ensuring they're valid indices
    # pylint: disable=duplicate-code
    target_token_ids = torch.randint(
        0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
    )

    # Generate random target logprobs (before normalization)
    target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")

    # Normalize the target logprobs to ensure they form a valid distribution
    target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)

    # Create a random mask with some tokens masked out
    target_mask = torch.randint(
        0, 2, (batch_size, seq_len, top_k), device="cuda"
    ).float()

    # Additional parameters
    num_items_in_batch = batch_size * seq_len
    kd_temperature = 1.0
    top_k_before_softmax = 0  # mode 0 first; mode 1 is tested further down
    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Reset gradients
    student_logits.grad.zero_()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare loss values
    print(f"Eager loss: {loss_eager.item()}")
    print(f"Triton loss: {loss_triton.item()}")
    loss_diff = abs(loss_eager.item() - loss_triton.item())
    print(f"Loss difference: {loss_diff}")
    assert loss_diff < 1e-5, "Loss values differ significantly!"

    # Compare gradients
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print(f"Max gradient difference: {grad_diff}")

    # Print some sample gradients
    sample_idx = (0, 0, 0)  # (batch, seq, vocab)
    print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
    print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        print(f"Max relative gradient difference: {rel_diff}")
        assert rel_diff < 1e-3, "Gradients differ significantly!"
    # Also test top_k_before_softmax = 1 mode
    top_k_before_softmax = 1

    # Fresh logits (with fresh gradients) for this mode
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare gradients for top_k_before_softmax = 1
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print("\nWith top_k_before_softmax=1:")
    print(f"Max gradient difference: {grad_diff}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        assert (
            rel_diff < 1e-3
        ), f"Gradients differ significantly; max relative gradient difference: {rel_diff}"
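# Optional convenience entry point (an addition, not part of the original checks):
# lets the sanity checks run directly with `python <this file>` as well as under pytest.
if __name__ == "__main__":
    test_kl_loss_gradient()
    print("\nAll KL loss sanity checks passed.")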