Compare commits
10 Commits
quantize-p ... topk-logpr
| Author | SHA1 | Date |
|---|---|---|
| | 68e97d032a | |
| | 23f029a89c | |
| | afbb44f08b | |
| | d753ead033 | |
| | c011405117 | |
| | a2e52a29e9 | |
| | e82268e580 | |
| | 75e1480c10 | |
| | 45e1548d59 | |
| | 165088e7c1 | |
391 src/axolotl/integrations/kd/topk_logprob/bench_kl.py Normal file

@@ -0,0 +1,391 @@
```python
"""
benchmark utility helper for benchmarking the KL divergence triton kernel
"""
import gc
import time

import torch
from torch.utils.benchmark import Timer

from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss


# pylint: disable=cell-var-from-loop
def benchmark_kl_div_loss_with_backward():
    # Test configurations
    batch_sizes = [1, 4]
    seq_lens = [64, 512, 2048, 4096, 8192]
    vocab_size = 32000
    top_k = 64

    # Store results
    results = []

    # Run benchmarks
    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            # Generate random test data
            torch.manual_seed(42)

            # Create tensors with gradients
            student_logits = torch.randn(
                batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
            )
            # pylint: disable=duplicate-code
            target_token_ids = torch.randint(
                0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
            )
            target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
            target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
            target_mask = torch.randint(
                0, 2, (batch_size, seq_len, top_k), device="cuda"
            ).float()

            # Clone student_logits for the two implementations
            student_logits_ref = student_logits.clone().detach().requires_grad_(True)
            student_logits_triton = student_logits.clone().detach().requires_grad_(True)

            # Define functions for timing that include both forward and backward passes
            def run_reference():
                # Forward pass
                loss_ref = eager_loss(
                    student_logits_ref, target_token_ids, target_logprobs, target_mask
                )
                # Backward pass
                loss_ref.backward()

            def run_triton():
                # Forward pass
                # pylint: disable=duplicate-code
                loss_triton = triton_loss(
                    student_logits_triton,
                    target_token_ids,
                    target_logprobs,
                    target_mask,
                )
                # Backward pass
                loss_triton.backward()

            # Benchmark reference implementation (forward + backward)
            t0 = Timer(
                stmt="run_reference()",
                globals={
                    "run_reference": run_reference,
                },
            )
            # Reset gradients before timing
            student_logits_ref.grad = None
            ref_time = t0.timeit(10).median * 1000  # Convert to ms

            # Benchmark Triton implementation (forward + backward)
            t1 = Timer(
                stmt="run_triton()",
                globals={
                    "run_triton": run_triton,
                },
            )
            # Reset gradients before timing
            student_logits_triton.grad = None
            triton_time = t1.timeit(10).median * 1000  # Convert to ms

            # Compute speedup
            speedup = ref_time / triton_time if triton_time > 0 else float("inf")

            # Store results
            results.append(
                {
                    "batch_size": batch_size,
                    "seq_len": seq_len,
                    "reference_time_ms": ref_time,
                    "triton_time_ms": triton_time,
                    "speedup": speedup,
                }
            )

            print(f"Batch size: {batch_size}, Seq len: {seq_len}")
            print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
            print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
            print(f" Speedup: {speedup:.2f}x")

    return results


def benchmark_forward_backward_separately():
    """
    Benchmark forward and backward passes separately to identify where the speedup comes from.
    """
    # Test configurations
    batch_sizes = [1, 4, 8]
    seq_lens = [64, 512, 2048]
    vocab_size = 32000
    top_k = 64

    # Store results
    detailed_results = []

    # Run benchmarks
    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            # Generate random test data
            torch.manual_seed(42)

            # Create tensors with gradients
            student_logits = torch.randn(
                batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
            )
            # pylint: disable=duplicate-code
            target_token_ids = torch.randint(
                0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
            )
            target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
            target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
            target_mask = torch.randint(
                0, 2, (batch_size, seq_len, top_k), device="cuda"
            ).float()

            # Clone student_logits for the two implementations
            student_logits_ref = student_logits.clone().detach().requires_grad_(True)
            student_logits_triton = student_logits.clone().detach().requires_grad_(True)

            # Forward-only reference
            def run_reference_forward():
                with torch.no_grad():
                    return eager_loss(
                        student_logits_ref,
                        target_token_ids,
                        target_logprobs,
                        target_mask,
                    )

            # Forward-only triton
            def run_triton_forward():
                with torch.no_grad():
                    return triton_loss(
                        student_logits_triton,
                        target_token_ids,
                        target_logprobs,
                        target_mask,
                    )

            # Benchmark forward pass only

            t0_fwd = Timer(
                stmt="run_reference_forward()",
                globals={
                    "run_reference_forward": run_reference_forward,
                },
            )
            ref_fwd_time = t0_fwd.timeit(10).median * 1000  # Convert to ms

            t1_fwd = Timer(
                stmt="run_triton_forward()",
                globals={
                    "run_triton_forward": run_triton_forward,
                },
            )
            triton_fwd_time = t1_fwd.timeit(10).median * 1000  # Convert to ms

            # Pre-compute losses for backward pass benchmarking
            loss_ref = eager_loss(
                student_logits_ref, target_token_ids, target_logprobs, target_mask
            )
            loss_triton = triton_loss(
                student_logits_triton, target_token_ids, target_logprobs, target_mask
            )

            # Backward-only reference
            def run_reference_backward():
                student_logits_ref.grad = None
                loss_ref.backward()

            # Backward-only triton
            def run_triton_backward():
                student_logits_triton.grad = None
                loss_triton.backward()

            # Benchmark backward pass only
            t0_bwd = Timer(
                stmt="run_reference_backward()",
                globals={
                    "run_reference_backward": run_reference_backward,
                },
            )
            ref_bwd_time = t0_bwd.timeit(10).median * 1000  # Convert to ms

            t1_bwd = Timer(
                stmt="run_triton_backward()",
                globals={
                    "run_triton_backward": run_triton_backward,
                },
            )
            triton_bwd_time = t1_bwd.timeit(10).median * 1000  # Convert to ms

            # Compute speedups
            fwd_speedup = (
                ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
            )
            bwd_speedup = (
                ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
            )
            total_ref_time = ref_fwd_time + ref_bwd_time
            total_triton_time = triton_fwd_time + triton_bwd_time
            total_speedup = (
                total_ref_time / total_triton_time
                if total_triton_time > 0
                else float("inf")
            )

            # Store results
            detailed_results.append(
                {
                    "batch_size": batch_size,
                    "seq_len": seq_len,
                    "ref_forward_ms": ref_fwd_time,
                    "triton_forward_ms": triton_fwd_time,
                    "forward_speedup": fwd_speedup,
                    "ref_backward_ms": ref_bwd_time,
                    "triton_backward_ms": triton_bwd_time,
                    "backward_speedup": bwd_speedup,
                    "total_ref_ms": total_ref_time,
                    "total_triton_ms": total_triton_time,
                    "total_speedup": total_speedup,
                }
            )

            print(f"Batch size: {batch_size}, Seq len: {seq_len}")
            print(
                f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
            )
            print(
                f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
            )
            print(
                f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
            )

    return detailed_results


def benchmark_memory_usage_with_backward():
    # Test configurations
    batch_sizes = [1, 2]
    seq_len = 8192
    vocab_size = 128000
    top_k = 64

    # Store results
    mem_results = []

    # Run benchmarks
    for batch_size in batch_sizes:
        # Generate random test data
        torch.manual_seed(42)
        student_logits = torch.randn(
            batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
        )
        target_token_ids = torch.randint(
            0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
        )
        target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
        target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
        target_mask = torch.randint(
            0, 2, (batch_size, seq_len, top_k), device="cuda"
        ).float()

        # Clone student_logits for the implementations
        student_logits_ref = student_logits.clone().detach().requires_grad_(True)
        student_logits_triton = student_logits.clone().detach().requires_grad_(True)

        # Measure PyTorch memory usage (forward + backward)
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        loss_ref = eager_loss(
            student_logits_ref, target_token_ids, target_logprobs, target_mask
        )
        loss_ref.backward()
        torch.cuda.synchronize()
        pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB

        # Measure Triton memory usage (forward + backward)
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        loss_triton = triton_loss(
            student_logits_triton, target_token_ids, target_logprobs, target_mask
        )
        loss_triton.backward()
        torch.cuda.synchronize()
        triton_mem = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB

        # Measure Triton memory usage with different chunk sizes (forward + backward)
        for n_chunks in [1, 2, 4, 8]:
            student_logits_chunk = student_logits.clone().detach().requires_grad_(True)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            loss_chunk = triton_loss(
                student_logits_chunk,
                target_token_ids,
                target_logprobs,
                target_mask,
            )
            loss_chunk.backward()
            torch.cuda.synchronize()
            chunk_mem = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB

            mem_results.append(
                {
                    "batch_size": batch_size,
                    "implementation": f"Triton (chunks={n_chunks})",
                    "memory_mb": chunk_mem,
                }
            )

        # Store results
        mem_results.append(
            {
                "batch_size": batch_size,
                "implementation": "PyTorch",
                "memory_mb": pytorch_mem,
            }
        )

        mem_results.append(
            {
                "batch_size": batch_size,
                "implementation": "Triton (default)",
                "memory_mb": triton_mem,
            }
        )

        print(f"Batch size: {batch_size} (with backward pass)")
        print(f" PyTorch memory: {pytorch_mem:.2f} MB")
        print(f" Triton memory: {triton_mem:.2f} MB")
        print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")

    return mem_results


def main():
    print("Running benchmarks with forward and backward passes...")
    benchmark_kl_div_loss_with_backward()
    clean()

    print("\nRunning detailed forward/backward benchmarks...")
    # benchmark_forward_backward_separately()
    # clean()

    print("\nRunning memory usage benchmarks with backward passes...")
    benchmark_memory_usage_with_backward()
    clean()


def clean():
    for _ in range(5):
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(1)


if __name__ == "__main__":
    main()
```
750 src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py Normal file

@@ -0,0 +1,750 @@
```python
"""
Optimized Triton kernel for KL divergence loss between teacher and student models.
"""
# pylint: disable=invalid-name,unused-argument
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def fused_logsumexp_logprobs_kernel(
    student_logits_ptr,  # Input logits in original dtype
    student_logprobs_ptr,  # Output logprobs (float32)
    token_ids_ptr,  # Token IDs for top-k
    B,
    S,
    V,
    K,  # batch size, seq len, vocab size, top-k
    temperature,
    stride_l_b,
    stride_l_s,
    stride_l_v,
    stride_lp_b,
    stride_lp_s,
    stride_lp_k,
    stride_t_b,
    stride_t_s,
    stride_t_k,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused kernel that computes logsumexp and logprobs for topk tokens.
    All computations are done in float32 for numerical stability.
    """
    # Program ID
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Compute logsumexp over the vocabulary
    max_val = -float("inf")

    # Phase 1: Find max value across vocabulary
    for v_offset in range(0, V, BLOCK_SIZE):
        # Create block indices and mask
        block_size = min(BLOCK_SIZE, V - v_offset)
        block_idx = tl.arange(0, BLOCK_SIZE)
        mask = block_idx < block_size

        # Load logits block and convert to float32 in-place
        ptrs = (
            student_logits_ptr
            + batch_idx * stride_l_b
            + seq_idx * stride_l_s
            + (v_offset + block_idx) * stride_l_v
        )
        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)

        # Apply temperature scaling if needed
        if temperature != 1.0:
            block_logits = block_logits / temperature

        # Update max value
        block_max = tl.max(block_logits, axis=0)
        max_val = tl.maximum(max_val, block_max)

    # Phase 2: Compute sum of exp(logits - max_val)
    sum_exp = 0.0

    for v_offset in range(0, V, BLOCK_SIZE):
        # Create block indices and mask
        block_size = min(BLOCK_SIZE, V - v_offset)
        block_idx = tl.arange(0, BLOCK_SIZE)
        mask = block_idx < block_size

        # Load logits block and convert to float32 in-place
        ptrs = (
            student_logits_ptr
            + batch_idx * stride_l_b
            + seq_idx * stride_l_s
            + (v_offset + block_idx) * stride_l_v
        )
        block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)

        # Apply temperature scaling if needed
        if temperature != 1.0:
            block_logits = block_logits / temperature

        # Compute exp(logits - max_val) and add to sum
        block_exp = tl.exp(block_logits - max_val)
        sum_exp += tl.sum(block_exp * mask, axis=0)

    # Compute final logsumexp
    logsumexp = max_val + tl.log(sum_exp)

    # Phase 3: Compute and store logprobs for the top-k tokens
    token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
    logprobs_base = (
        student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
    )

    for k in range(K):
        # Load token ID for position k
        token_id = tl.load(token_ids_base + k * stride_t_k)

        # Load the corresponding logit and convert to float32
        token_logit_ptr = (
            student_logits_ptr
            + batch_idx * stride_l_b
            + seq_idx * stride_l_s
            + token_id * stride_l_v
        )
        token_logit = tl.load(token_logit_ptr).to(tl.float32)

        # Apply temperature scaling if needed
        if temperature != 1.0:
            token_logit = token_logit / temperature

        # Compute logprob directly: logit - logsumexp
        token_logprob = token_logit - logsumexp

        # Store the result
        tl.store(logprobs_base + k * stride_lp_k, token_logprob)
```
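In math form, each program instance of the kernel above computes the max-shifted logsumexp over the vocabulary and then the log-probabilities of only the gathered top-k tokens (this is just a summary of the code, written out for review):

$$
\operatorname{logsumexp}\!\left(\tfrac{z}{T}\right) = m + \log \sum_{v=1}^{V} e^{z_v/T - m},
\qquad m = \max_v \tfrac{z_v}{T},
\qquad
\log p^{S}_{k} = \tfrac{z_{t_k}}{T} - \operatorname{logsumexp}\!\left(\tfrac{z}{T}\right),
$$

where $z$ is the student logit row, $T$ is `temperature`, and $t_k$ is the $k$-th teacher-provided token id. Subtracting the running maximum $m$ before exponentiating is what keeps the two vocabulary passes stable in float32.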
```python
@triton.jit
def grad_softmax_kernel(
    grad_student_logits_ptr,
    target_token_ids_ptr,
    teacher_probs_ptr,
    student_probs_ptr,
    mask_ptr,
    B,
    S,
    V,
    K,  # batch size, seq len, vocab size, top-k
    scale,
    stride_gl_b,
    stride_gl_s,
    stride_gl_v,
    stride_t_b,
    stride_t_s,
    stride_t_k,
    stride_p_b,
    stride_p_s,
    stride_p_k,
    stride_sp_b,
    stride_sp_s,
    stride_sp_k,
    stride_m_b,
    stride_m_s,
    stride_m_k,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Base pointers for this (batch, seq) pair
    grad_logits_base = (
        grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
    )
    token_ids_base = (
        target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
    )
    teacher_probs_base = (
        teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
    )
    student_probs_base = (
        student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
    )
    mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s

    # Process each teacher probability one at a time, computing all gradients for it
    for k in range(0, K):
        # Load data for current position k
        teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
        student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
        mask_val = tl.load(mask_base + k * stride_m_k)

        # Precompute the self-influence term (multiplied by scale)
        self_term = teacher_prob * (1.0 - student_prob_k) * scale

        # Calculate gradient contributions for all positions j
        for j in range(0, K):
            token_id_j = tl.load(token_ids_base + j * stride_t_k)
            student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
            mask_j = tl.load(mask_base + j * stride_m_k)

            # Calculate the masking factor
            combined_mask = mask_val * mask_j

            # Determine if this is a diagonal or off-diagonal term
            is_k_equals_j = tl.where(k == j, 1.0, 0.0)

            # Compute the gradient contribution
            # For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
            # For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
            grad_contribution = (
                -(
                    self_term * is_k_equals_j
                    - teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
                )
                * combined_mask
            )

            # Atomically update the gradient for this token
            tl.atomic_add(
                grad_logits_base + token_id_j * stride_gl_v, grad_contribution
            )
```
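The gradient this kernel scatters into `grad_student_logits` follows from differentiating the forward-KL term with respect to the gathered student logits. As a sanity-check sketch of the unmasked case (the mask merely zeroes pairs of contributions): with teacher probabilities $p^{T}$ and student probabilities $p^{S}$ on the top-k slots,

$$
L=\sum_{k} p^{T}_{k}\left(\log p^{T}_{k}-\log p^{S}_{k}\right),
\qquad
\frac{\partial \log p^{S}_{k}}{\partial z_{j}}=\delta_{kj}-p^{S}_{j}
\;\Longrightarrow\;
\frac{\partial L}{\partial z_{j}}
=-\,p^{T}_{j}\left(1-p^{S}_{j}\right)+\sum_{k\neq j} p^{T}_{k}\,p^{S}_{j}.
$$

The first term is the diagonal `self_term` above and the second is the off-diagonal cross term; both are multiplied by `scale`, which folds in the upstream gradient, the $T^{2}$ factor, the $1/T$ chain-rule factor, and the batch normalization (see `backward` further down).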
```python
@triton.jit
def grad_topk_softmax_kernel(
    grad_student_logits_ptr,
    student_logits_ptr,
    target_token_ids_ptr,
    teacher_probs_ptr,
    student_probs_ptr,
    mask_ptr,
    B,
    S,
    V,
    K,  # batch size, seq len, vocab size, top-k
    scale,
    stride_gl_b,
    stride_gl_s,
    stride_gl_v,
    stride_l_b,
    stride_l_s,
    stride_l_v,
    stride_t_b,
    stride_t_s,
    stride_t_k,
    stride_p_b,
    stride_p_s,
    stride_p_k,
    stride_sp_b,
    stride_sp_s,
    stride_sp_k,
    stride_m_b,
    stride_m_s,
    stride_m_k,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Base pointers for this (batch, seq) pair
    grad_logits_base = (
        grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
    )
    # logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
    token_ids_base = (
        target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
    )
    teacher_probs_base = (
        teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
    )
    student_probs_base = (
        student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
    )
    mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s

    # Load all token IDs, probs and masks for this position
    token_ids = tl.zeros([K], dtype=tl.int32)
    teacher_probs = tl.zeros([K], dtype=tl.float32)
    student_probs = tl.zeros([K], dtype=tl.float32)
    masks = tl.zeros([K], dtype=tl.float32)

    for k in range(K):
        token_ids[k] = tl.load(token_ids_base + k * stride_t_k)
        teacher_probs[k] = tl.load(teacher_probs_base + k * stride_p_k)
        student_probs[k] = tl.load(student_probs_base + k * stride_sp_k)
        masks[k] = tl.load(mask_base + k * stride_m_k)

    # Process gradients for all tokens in this position
    for k in range(K):
        # token_id = token_ids[k]
        mask_k = masks[k]

        # Skip computation if mask is zero by multiplying gradient by mask
        for j in range(K):
            other_token_id = token_ids[j]
            mask_j = masks[j]
            combined_mask = mask_k * mask_j

            # Compute gradient differently for diagonal vs off-diagonal entries
            # Using * 1.0 to convert boolean to float
            is_diagonal = tl.where(j == k, 1.0, 0.0)

            # Self influence: gradient = teacher_prob * (1 - student_prob)
            self_grad = teacher_probs[k] * (1.0 - student_probs[k]) * is_diagonal

            # Cross influence: gradient = -teacher_prob[k] * student_prob[j]
            cross_grad = -teacher_probs[k] * student_probs[j] * (1.0 - is_diagonal)

            # Combined gradient scaled by mask
            grad_val = (self_grad + cross_grad) * scale * combined_mask

            tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)


# Triton-accelerated implementation of KL divergence loss for top-k tokens
# Chunking helper functions for handling long sequences
def chunk_tensor(
    tensor: torch.Tensor, max_seq_len: int
) -> Tuple[torch.Tensor, Optional[int]]:
    """Split a tensor along sequence dimension if needed."""
    _, seq_len, *__ = tensor.shape

    if seq_len <= max_seq_len:
        return tensor, None

    num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
    chunks = []

    for i in range(num_chunks):
        start_idx = i * max_seq_len
        end_idx = min((i + 1) * max_seq_len, seq_len)
        chunks.append(tensor[:, start_idx:end_idx, ...])

    return chunks, num_chunks


def merge_chunks(chunks: list, original_shape: torch.Size):
    """Merge chunks back into a single tensor with original shape."""
    return torch.cat(chunks, dim=1)
```
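A minimal sketch of how these helpers compose, with hypothetical shapes chosen for illustration. Note the asymmetric return of `chunk_tensor`: when no split is needed, the first element is the original tensor and the count is `None`; otherwise it is a list of chunks plus the chunk count, so callers branch on the second value.

```python
import torch

x = torch.zeros(2, 10_000, 64)  # assumed [batch, seq, top_k] tensor longer than the chunk limit
chunks, num_chunks = chunk_tensor(x, max_seq_len=8192)
if num_chunks is None:
    merged = chunks  # no chunking happened; `chunks` is just the original tensor
else:
    merged = merge_chunks(chunks, x.shape)  # concatenates the chunks along the sequence dimension
assert merged.shape == x.shape
```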
```python
# Triton-accelerated implementation of KL divergence loss for top-k tokens
class TopKKLDivergence(torch.autograd.Function):
    """
    Autograd function for KL divergence loss between top-k logprobs
    with support for chunking to handle very long sequences.
    """

    # Max sequence length to process in a single kernel launch
    # This is a tunable parameter that might need adjustment based on GPU memory
    MAX_SEQ_LEN = 8192

    @staticmethod
    def forward(
        ctx,
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch=-1,
        kd_temperature=1.0,
        top_k_before_softmax=0,
    ):
        """
        Forward pass for KL divergence loss between top-k logprobs with chunking.
        """
        # Only convert target_logprobs to float, leave student_logits as is
        target_logprobs = target_logprobs.float()

        # Get dimensions
        batch_size, _, vocab_size = student_logits.shape
        _, teacher_seq_len, top_k = target_token_ids.shape

        # Slice student logits to match teacher sequence length
        student_logits_for_kd = student_logits[:, :teacher_seq_len, :]

        # Store original values for backward pass
        ctx.original_seq_len = teacher_seq_len
        ctx.original_dtype = student_logits.dtype

        # Apply chunking for long sequences
        if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
            # Chunk the inputs
            student_logits_chunks, num_chunks = chunk_tensor(
                student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
            )
            target_token_ids_chunks, _ = chunk_tensor(
                target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
            )
            # target_logprobs_chunks, _ = chunk_tensor(
            #     target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
            # )
            # target_mask_chunks, _ = chunk_tensor(
            #     target_mask, TopKKLDivergence.MAX_SEQ_LEN
            # )

            # Process each chunk
            student_logprobs_chunks = []
            student_probs_chunks = []

            for i in range(num_chunks):
                chunk_logits = student_logits_chunks[i]
                chunk_token_ids = target_token_ids_chunks[i]
                chunk_seq_len = chunk_logits.shape[1]

                if top_k_before_softmax:
                    # Apply temperature to student logits
                    if kd_temperature != 1.0:
                        chunk_logits = chunk_logits / kd_temperature

                    # Gather student logits for top-k tokens
                    chunk_logits_topk = torch.gather(
                        chunk_logits, dim=-1, index=chunk_token_ids
                    )

                    # Compute softmax over gathered logits
                    chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
                    chunk_probs_topk = torch.exp(chunk_logprobs_topk)
                else:
                    # Allocate output tensor for logprobs directly (always in float32)
                    chunk_logprobs_topk = torch.empty(
                        (batch_size, chunk_seq_len, top_k),
                        dtype=torch.float32,
                        device=chunk_logits.device,
                    )

                    # Launch fused kernel directly
                    grid = (batch_size * chunk_seq_len,)
                    fused_logsumexp_logprobs_kernel[grid](
                        chunk_logits.contiguous(),
                        chunk_logprobs_topk,
                        chunk_token_ids.contiguous(),
                        batch_size,
                        chunk_seq_len,
                        vocab_size,
                        top_k,
                        kd_temperature,
                        chunk_logits.stride(0),
                        chunk_logits.stride(1),
                        chunk_logits.stride(2),
                        chunk_logprobs_topk.stride(0),
                        chunk_logprobs_topk.stride(1),
                        chunk_logprobs_topk.stride(2),
                        chunk_token_ids.stride(0),
                        chunk_token_ids.stride(1),
                        chunk_token_ids.stride(2),
                        min(1024, triton.next_power_of_2(vocab_size)),
                    )

                    # Calculate probs from logprobs
                    chunk_probs_topk = torch.exp(chunk_logprobs_topk)

                # Store results
                student_logprobs_chunks.append(chunk_logprobs_topk)
                student_probs_chunks.append(chunk_probs_topk)

            # Merge results
            student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
            student_probs_topk = torch.cat(student_probs_chunks, dim=1)

            # Save chunking info for backward pass
            ctx.used_chunking = True
            ctx.num_chunks = num_chunks

        else:
            # Original code path for shorter sequences
            if top_k_before_softmax:
                # Apply temperature to student logits
                if kd_temperature != 1.0:
                    student_logits_for_kd = student_logits_for_kd / kd_temperature

                # Gather student logits for top-k tokens
                student_logits_topk = torch.gather(
                    student_logits_for_kd, dim=-1, index=target_token_ids
                )

                # Compute softmax over gathered logits
                student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
                student_probs_topk = torch.exp(student_logprobs_topk)
            else:
                # Allocate output tensor for logprobs directly (always in float32)
                student_logprobs_topk = torch.empty(
                    (batch_size, teacher_seq_len, top_k),
                    dtype=torch.float32,
                    device=student_logits.device,
                )

                # Launch fused kernel directly
                grid = (batch_size * teacher_seq_len,)
                fused_logsumexp_logprobs_kernel[grid](
                    student_logits_for_kd.contiguous(),
                    student_logprobs_topk,
                    target_token_ids.contiguous(),
                    batch_size,
                    teacher_seq_len,
                    vocab_size,
                    top_k,
                    kd_temperature,
                    student_logits_for_kd.stride(0),
                    student_logits_for_kd.stride(1),
                    student_logits_for_kd.stride(2),
                    student_logprobs_topk.stride(0),
                    student_logprobs_topk.stride(1),
                    student_logprobs_topk.stride(2),
                    target_token_ids.stride(0),
                    target_token_ids.stride(1),
                    target_token_ids.stride(2),
                    min(1024, triton.next_power_of_2(vocab_size)),
                )

                # Calculate probs from logprobs
                student_probs_topk = torch.exp(student_logprobs_topk)

            # No chunking used
            ctx.used_chunking = False

        # Save tensors for backward pass
        ctx.save_for_backward(
            student_logits_for_kd,
            target_token_ids,
            target_logprobs,
            target_mask,
            student_probs_topk,
        )
        ctx.kd_temperature = kd_temperature
        ctx.top_k_before_softmax = top_k_before_softmax
        ctx.num_items_in_batch = num_items_in_batch

        # Convert mask to boolean
        valid_mask = target_mask.bool()

        # Extract valid tokens only - this is where the error was happening
        # Use cloned contiguous tensors and explicit indexing for safety
        student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
        target_logprobs_flat = target_logprobs.view(-1, top_k)
        valid_mask_flat = valid_mask.view(-1, top_k)

        # Gather valid indices explicitly to avoid illegal memory access
        valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
        student_logprobs_valid = torch.index_select(
            student_logprobs_flat.view(-1), 0, valid_indices
        )
        target_logprobs_valid = torch.index_select(
            target_logprobs_flat.view(-1), 0, valid_indices
        )

        # Convert teacher logprobs to probabilities
        teacher_probs_valid = torch.exp(target_logprobs_valid)

        # Compute KL divergence loss
        token_losses = teacher_probs_valid * (
            target_logprobs_valid - student_logprobs_valid
        )
        kd_loss = token_losses.sum()

        # Apply temperature scaling
        # pylint: disable=duplicate-code
        if kd_temperature != 1.0:
            kd_loss = kd_loss * (kd_temperature**2)

        # Normalize by number of items or valid tokens
        if num_items_in_batch > 0:
            kd_loss = kd_loss / float(num_items_in_batch)
        else:
            num_valid_tokens = valid_indices.numel()
            kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)

        return kd_loss
```
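Putting the pieces of `forward` together, the scalar it returns can be written as follows, with $m$ the top-k mask, $T$ the KD temperature, and $N$ either `num_items_in_batch` or the number of unmasked entries:

$$
\mathcal{L}_{\mathrm{KD}}
=\frac{T^{2}}{N}
\sum_{b,s,k} m_{bsk}\; p^{T}_{bsk}\left(\log p^{T}_{bsk}-\log p^{S}_{bsk}\right).
$$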
```python
    @staticmethod
    def backward(ctx, grad_output):
        """
        Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
        """
        (
            student_logits,
            target_token_ids,
            target_logprobs,
            target_mask,
            student_probs,
        ) = ctx.saved_tensors
        kd_temperature = ctx.kd_temperature
        num_items_in_batch = ctx.num_items_in_batch
        original_dtype = ctx.original_dtype

        # Get dimensions
        batch_size, _, vocab_size = student_logits.shape
        _, teacher_seq_len, top_k = target_token_ids.shape

        # Initialize gradient tensor in float32 to support atomic operations
        grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)

        # Compute scaling factor
        scale = grad_output.item()

        # Apply temperature scaling from forward pass
        if kd_temperature != 1.0:
            scale = scale * (kd_temperature**2)

        # Normalize by number of items or valid tokens
        if num_items_in_batch > 0:
            scale = scale / float(num_items_in_batch)
        else:
            scale = scale / float(target_mask.sum().item())

        # Apply chain rule for temperature scaling (1/temperature)
        if kd_temperature != 1.0:
            scale = scale / kd_temperature

        # Convert teacher logprobs to probabilities
        teacher_probs = torch.exp(target_logprobs)

        # Use chunking for the backward pass if used in forward
        if getattr(ctx, "used_chunking", False):
            num_chunks = ctx.num_chunks
            max_seq = TopKKLDivergence.MAX_SEQ_LEN

            # Process each chunk
            for i in range(num_chunks):
                start_idx = i * max_seq
                end_idx = min((i + 1) * max_seq, teacher_seq_len)
                chunk_len = end_idx - start_idx

                # Get chunk slices
                # student_logits_chunk = student_logits[:, start_idx:end_idx, :]
                target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
                teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
                student_probs_chunk = student_probs[:, start_idx:end_idx, :]
                target_mask_chunk = target_mask[:, start_idx:end_idx, :]
                grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]

                # Launch gradient computation kernel for this chunk
                grid = (batch_size * chunk_len,)
                grad_softmax_kernel[grid](
                    grad_student_logits_chunk.contiguous(),
                    target_token_ids_chunk.contiguous(),
                    teacher_probs_chunk.contiguous(),
                    student_probs_chunk.contiguous(),
                    target_mask_chunk.contiguous(),
                    batch_size,
                    chunk_len,
                    vocab_size,
                    top_k,
                    scale,
                    grad_student_logits_chunk.stride(0),
                    grad_student_logits_chunk.stride(1),
                    grad_student_logits_chunk.stride(2),
                    target_token_ids_chunk.stride(0),
                    target_token_ids_chunk.stride(1),
                    target_token_ids_chunk.stride(2),
                    teacher_probs_chunk.stride(0),
                    teacher_probs_chunk.stride(1),
                    teacher_probs_chunk.stride(2),
                    student_probs_chunk.stride(0),
                    student_probs_chunk.stride(1),
                    student_probs_chunk.stride(2),
                    target_mask_chunk.stride(0),
                    target_mask_chunk.stride(1),
                    target_mask_chunk.stride(2),
                    min(1024, triton.next_power_of_2(top_k)),
                )

                # Update the gradient tensor (already in-place)
        else:
            # Original code path for shorter sequences
            # Launch gradient computation kernel
            grid = (batch_size * teacher_seq_len,)
            grad_softmax_kernel[grid](
                grad_student_logits.contiguous(),
                target_token_ids.contiguous(),
                teacher_probs.contiguous(),
                student_probs.contiguous(),
                target_mask.contiguous(),
                batch_size,
                teacher_seq_len,
                vocab_size,
                top_k,
                scale,
                grad_student_logits.stride(0),
                grad_student_logits.stride(1),
                grad_student_logits.stride(2),
                target_token_ids.stride(0),
                target_token_ids.stride(1),
                target_token_ids.stride(2),
                teacher_probs.stride(0),
                teacher_probs.stride(1),
                teacher_probs.stride(2),
                student_probs.stride(0),
                student_probs.stride(1),
                student_probs.stride(2),
                target_mask.stride(0),
                target_mask.stride(1),
                target_mask.stride(2),
                min(1024, triton.next_power_of_2(top_k)),
            )

        # Convert gradient back to original dtype if needed
        if original_dtype != torch.float32:
            grad_student_logits = grad_student_logits.to(original_dtype)

        # Return gradients for student_logits and None for other inputs
        return grad_student_logits, None, None, None, None, None, None


# Wrapper function for chunked computation
def loss(
    student_logits: torch.Tensor,
    target_token_ids: torch.Tensor,
    target_logprobs: torch.Tensor,
    target_mask: torch.Tensor,
    num_items_in_batch: int = -1,
    kd_temperature: float = 1.0,
    top_k_before_softmax: int = 0,
    max_seq_len: Optional[int] = None,
):
    """
    Triton-accelerated memory-efficient KL divergence loss computation for knowledge distillation
    with support for very long sequences.

    Args:
        student_logits: Student logits [B, seq_len, vocab_size]
        target_token_ids: Teacher token IDs [B, seq_len, top_k]
        target_logprobs: Teacher logprobs [B, seq_len, top_k]
        target_mask: Token mask [B, seq_len, top_k]
        num_items_in_batch: Number of items for normalization (-1 for auto)
        kd_temperature: Temperature for KD
        top_k_before_softmax: Flag for softmax application order
        max_seq_len: Override default MAX_SEQ_LEN value for chunking
    """
    # Allow overriding the max sequence length
    if max_seq_len is not None and max_seq_len > 0:
        TopKKLDivergence.MAX_SEQ_LEN = max_seq_len

    total_loss = TopKKLDivergence.apply(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        -1 if num_items_in_batch <= 0 else num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )

    return total_loss
```
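A self-contained call sketch for the wrapper above; the shapes, CUDA device, and temperature here are illustrative assumptions that mirror what the KD trainer and the tests below pass in.

```python
import torch

from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss

batch, seq_len, vocab, top_k = 2, 128, 32000, 64
student_logits = torch.randn(batch, seq_len, vocab, device="cuda", requires_grad=True)
target_token_ids = torch.randint(0, vocab, (batch, seq_len, top_k), device="cuda")
target_logprobs = torch.log_softmax(
    torch.randn(batch, seq_len, top_k, device="cuda"), dim=-1
)
target_mask = torch.ones(batch, seq_len, top_k, device="cuda")

kd_loss = triton_loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=batch * seq_len,
    kd_temperature=1.0,
)
kd_loss.backward()  # gradients reach student_logits through the custom autograd.Function
```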
67 src/axolotl/integrations/kd/topk_logprob/logsumexp.py Normal file

@@ -0,0 +1,67 @@
```python
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl


# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    # pylint: disable=duplicate-code
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointers
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # Find maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Compute sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # Compute logsumexp
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
```
```diff
@@ -20,6 +20,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
 
 from .topk_logprob.forward_kl import loss as topk_kd_loss
 from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
+from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
 
 
 class AxolotlKDTrainer(AxolotlTrainer):
@@ -85,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
                 num_items_in_batch=num_items_in_batch,
             )
         else:
-            loss_kd = topk_kd_loss(
+            loss_fn = (
+                topk_kd_loss
+                if self.args.kd_top_k_before_softmax
+                else topk_kd_loss_triton
+            )
+            loss_kd = loss_fn(
                 shift_logits,
                 target_token_ids_for_loss,
                 target_logprobs_for_loss,
```
```diff
@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )
 
     @pytest.mark.parametrize(
         "load_in_8bit",
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )
```
163 tests/e2e/integrations/test_kl_loss.py Normal file

@@ -0,0 +1,163 @@
```python
"""
sanity checks on kl loss and gradients
"""
import torch

# Import both implementations
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss


def test_kl_loss_gradient():
    """Test that the gradient of the Triton implementation matches the eager implementation."""

    # Set the random seed for reproducibility
    torch.manual_seed(42)

    # Create random inputs
    batch_size = 2
    seq_len = 3
    vocab_size = 100
    top_k = 5

    # Generate random student logits
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Generate random target token IDs, ensuring they're valid indices
    # pylint: disable=duplicate-code
    target_token_ids = torch.randint(
        0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
    )

    # Generate random target logprobs (before normalization)
    target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")

    # Normalize the target logprobs to ensure they form a valid distribution
    target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)

    # Create a random mask with some tokens masked out
    target_mask = torch.randint(
        0, 2, (batch_size, seq_len, top_k), device="cuda"
    ).float()

    # Additional parameters
    num_items_in_batch = batch_size * seq_len
    kd_temperature = 1.0
    top_k_before_softmax = 0  # Test both modes

    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Reset gradients
    student_logits.grad.zero_()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare loss values
    print(f"Eager loss: {loss_eager.item()}")
    print(f"Triton loss: {loss_triton.item()}")
    loss_diff = abs(loss_eager.item() - loss_triton.item())
    print(f"Loss difference: {loss_diff}")
    assert loss_diff < 1e-5, "Loss values differ significantly!"

    # Compare gradients
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print(f"Max gradient difference: {grad_diff}")

    # Print some sample gradients
    sample_idx = (0, 0, 0)  # (batch, seq, vocab)
    print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
    print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        print(f"Max relative gradient difference: {rel_diff}")
        assert rel_diff < 1e-3, "Gradients differ significantly!"

    # Also test top_k_before_softmax = 1 mode
    top_k_before_softmax = 1

    # Reset the gradients
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare gradients for top_k_before_softmax = 1
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print("\nWith top_k_before_softmax=1:")
    print(f"Max gradient difference: {grad_diff}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        assert (
            rel_diff < 1e-3
        ), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
```
204 tests/e2e/integrations/test_logsumexp.py Normal file

@@ -0,0 +1,204 @@
```python
"""
sanity checks on logsumexp kernel validity
"""
import torch
import triton

from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel


# PyTorch implementation of logsumexp for reference
def torch_logsumexp(logits):
    """PyTorch implementation of logsumexp over last dimension"""
    return torch.logsumexp(logits, dim=-1)


# Wrapper function for Triton logsumexp kernel
def triton_logsumexp(logits):
    """Triton implementation of logsumexp over last dimension"""
    B, S, V = logits.shape  # pylint: disable=invalid-name
    output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

    grid = (B * S,)
    logsumexp_kernel[grid](
        logits.contiguous(),
        output,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        output.stride(0),
        output.stride(1),
        min(1024, triton.next_power_of_2(V)),
    )

    return output


class TritonLogSumExp(torch.autograd.Function):
    """
    Wrap a custom autograd function to use the Triton logsumexp for gradient testing
    """

    @staticmethod
    def forward(ctx, logits):
        B, S, V = logits.shape  # pylint: disable=invalid-name
        output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

        # Save inputs for backward pass
        ctx.save_for_backward(logits)
        ctx.shape = logits.shape

        grid = (B * S,)
        logsumexp_kernel[grid](
            logits.contiguous(),
            output,
            B,
            S,
            V,
            logits.stride(0),
            logits.stride(1),
            logits.stride(2),
            output.stride(0),
            output.stride(1),
            min(1024, triton.next_power_of_2(V)),
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors

        # For logsumexp, the gradient is softmax(input) * grad_output
        # First compute the logsumexp
        lse = TritonLogSumExp.apply(logits)

        # Compute softmax by exponentiating differences
        softmax_output = torch.exp(logits - lse.unsqueeze(-1))

        # Compute gradient of logsumexp by multiplying the softmax output by the gradient
        grad_input = softmax_output * grad_output.unsqueeze(-1)

        return grad_input


def test_logsumexp_values():
    """Test that the Triton logsumexp implementation matches PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Test with various input shapes
    test_shapes = [
        (2, 3, 10),  # small vocab
        (4, 5, 100),  # medium vocab
        (2, 2, 32000),  # large vocab (typical for LLMs)
    ]

    for shape in test_shapes:
        # Create random input tensors
        logits = torch.randn(shape, device="cuda", requires_grad=False)

        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Shape {shape}, Max diff: {max_diff}")

        # Assert that the results are very close
        assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"


def test_logsumexp_edge_cases():
    """Test edge cases for numerical stability"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Case 1: Very large values that might cause overflow
    logits_large = torch.ones(2, 3, 100, device="cuda") * 1000

    # Case 2: Very small values that might cause underflow
    logits_small = torch.ones(2, 3, 100, device="cuda") * -1000

    # Case 3: Mix of large and small values
    logits_mixed = torch.zeros(2, 3, 100, device="cuda")
    logits_mixed[:, :, 0] = 1000  # One very large value

    # Case 4: All identical values
    logits_identical = torch.ones(2, 3, 100, device="cuda") * 5

    # Case 5: Extreme values with NaN check
    logits_extreme = torch.cat(
        [
            torch.full((1, 3, 50), 1e10, device="cuda"),
            torch.full((1, 3, 50), -1e10, device="cuda"),
        ],
        dim=0,
    )

    for i, logits in enumerate(
        [logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
    ):
        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Check for NaNs
        assert not torch.isnan(
            torch_result
        ).any(), f"PyTorch produced NaNs for case {i+1}"
        assert not torch.isnan(
            triton_result
        ).any(), f"Triton produced NaNs for case {i+1}"

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Edge case {i+1}, Max diff: {max_diff}")

        # For very extreme values, allow a bit more tolerance
        if i == 4:  # extreme case
            assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
        else:
            assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"


def test_logsumexp_gradients():
    """Test that the gradients of Triton logsumexp match PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create input tensors with gradients enabled
    shapes = [(2, 3, 10), (4, 5, 100)]

    for shape in shapes:
        # Create two identical tensors for PyTorch and Triton
        logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
        logits_triton = logits_torch.clone().detach().requires_grad_(True)

        # Forward pass
        torch_output = torch_logsumexp(logits_torch)
        triton_output = TritonLogSumExp.apply(logits_triton)

        # Compare forward pass values
        max_diff_forward = (torch_output - triton_output).abs().max().item()
        assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"

        # Create random gradient
        grad_output = torch.randn_like(torch_output)

        # Backward pass
        torch_output.backward(grad_output)
        triton_output.backward(grad_output)

        # Compare gradients
        max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
        print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")

        # Assert that gradients are very close
        assert (
            max_diff_grad < 1e-5
        ), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
```
```diff
@@ -102,7 +102,11 @@ def is_hopper():
 
 
 def check_tensorboard(
-    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+    temp_run_dir: str,
+    tag: str,
+    comparison_val: float,
+    assertion_err: str,
+    lt: bool = True,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
     reader = SummaryReader(event_file)
     df = reader.scalars  # pylint: disable=invalid-name
     df = df[(df.tag == tag)]  # pylint: disable=invalid-name
-    if "%s" in assertion_err:
-        assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
-    else:
-        assert df.value.values[-1] < lt_val, assertion_err
+    if lt:
+        if "%s" in assertion_err:
+            assert df.value.values[-1] < comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] < comparison_val, assertion_err
+    else:
+        if "%s" in assertion_err:
+            assert df.value.values[-1] > comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] > comparison_val, assertion_err
 
 
 def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
```