optimize and include bench util
This commit is contained in:
388
src/axolotl/integrations/kd/topk_logprob/bench_kl.py
Normal file
388
src/axolotl/integrations/kd/topk_logprob/bench_kl.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
benchmark utility helper for benchmarking the KL divergence triton kernel
|
||||
"""
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def benchmark_kl_div_loss_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4]
|
||||
seq_lens = [64, 512, 2048, 4096, 8192]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Define functions for timing that include both forward and backward passes
|
||||
def run_reference():
|
||||
# Forward pass
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
# Backward pass
|
||||
loss_ref.backward()
|
||||
|
||||
def run_triton():
|
||||
# Forward pass
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
# Backward pass
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark reference implementation (forward + backward)
|
||||
t0 = Timer(
|
||||
stmt="run_reference()",
|
||||
globals={
|
||||
"run_reference": run_reference,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_ref.grad = None
|
||||
ref_time = t0.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Benchmark Triton implementation (forward + backward)
|
||||
t1 = Timer(
|
||||
stmt="run_triton()",
|
||||
globals={
|
||||
"run_triton": run_triton,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_triton.grad = None
|
||||
triton_time = t1.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedup
|
||||
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"reference_time_ms": ref_time,
|
||||
"triton_time_ms": triton_time,
|
||||
"speedup": speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
|
||||
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def benchmark_forward_backward_separately():
|
||||
"""
|
||||
Benchmark forward and backward passes separately to identify where the speedup comes from.
|
||||
"""
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4, 8]
|
||||
seq_lens = [64, 512, 2048]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
detailed_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward-only reference
|
||||
def run_reference_forward():
|
||||
with torch.no_grad():
|
||||
return eager_loss(
|
||||
student_logits_ref,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Forward-only triton
|
||||
def run_triton_forward():
|
||||
with torch.no_grad():
|
||||
return triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Benchmark forward pass only
|
||||
|
||||
t0_fwd = Timer(
|
||||
stmt="run_reference_forward()",
|
||||
globals={
|
||||
"run_reference_forward": run_reference_forward,
|
||||
},
|
||||
)
|
||||
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_fwd = Timer(
|
||||
stmt="run_triton_forward()",
|
||||
globals={
|
||||
"run_triton_forward": run_triton_forward,
|
||||
},
|
||||
)
|
||||
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Pre-compute losses for backward pass benchmarking
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
|
||||
# Backward-only reference
|
||||
def run_reference_backward():
|
||||
student_logits_ref.grad = None
|
||||
loss_ref.backward()
|
||||
|
||||
# Backward-only triton
|
||||
def run_triton_backward():
|
||||
student_logits_triton.grad = None
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark backward pass only
|
||||
t0_bwd = Timer(
|
||||
stmt="run_reference_backward()",
|
||||
globals={
|
||||
"run_reference_backward": run_reference_backward,
|
||||
},
|
||||
)
|
||||
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_bwd = Timer(
|
||||
stmt="run_triton_backward()",
|
||||
globals={
|
||||
"run_triton_backward": run_triton_backward,
|
||||
},
|
||||
)
|
||||
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedups
|
||||
fwd_speedup = (
|
||||
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
|
||||
)
|
||||
bwd_speedup = (
|
||||
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
|
||||
)
|
||||
total_ref_time = ref_fwd_time + ref_bwd_time
|
||||
total_triton_time = triton_fwd_time + triton_bwd_time
|
||||
total_speedup = (
|
||||
total_ref_time / total_triton_time
|
||||
if total_triton_time > 0
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
# Store results
|
||||
detailed_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"ref_forward_ms": ref_fwd_time,
|
||||
"triton_forward_ms": triton_fwd_time,
|
||||
"forward_speedup": fwd_speedup,
|
||||
"ref_backward_ms": ref_bwd_time,
|
||||
"triton_backward_ms": triton_bwd_time,
|
||||
"backward_speedup": bwd_speedup,
|
||||
"total_ref_ms": total_ref_time,
|
||||
"total_triton_ms": total_triton_time,
|
||||
"total_speedup": total_speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(
|
||||
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
|
||||
)
|
||||
|
||||
return detailed_results
|
||||
|
||||
|
||||
def benchmark_memory_usage_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 2]
|
||||
seq_len = 8192
|
||||
vocab_size = 128000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
mem_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Measure PyTorch memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_ref.backward()
|
||||
torch.cuda.synchronize()
|
||||
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton.backward()
|
||||
torch.cuda.synchronize()
|
||||
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage with different chunk sizes (forward + backward)
|
||||
for n_chunks in [1, 2, 4, 8]:
|
||||
student_logits_chunk = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_chunk = triton_loss(
|
||||
student_logits_chunk,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
loss_chunk.backward()
|
||||
torch.cuda.synchronize()
|
||||
chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": f"Triton (chunks={n_chunks})",
|
||||
"memory_mb": chunk_mem,
|
||||
}
|
||||
)
|
||||
|
||||
# Store results
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "PyTorch",
|
||||
"memory_mb": pytorch_mem,
|
||||
}
|
||||
)
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "Triton (default)",
|
||||
"memory_mb": triton_mem,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size} (with backward pass)")
|
||||
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
|
||||
print(f" Triton memory: {triton_mem:.2f} MB")
|
||||
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
|
||||
|
||||
return mem_results
|
||||
|
||||
|
||||
def main():
|
||||
print("Running benchmarks with forward and backward passes...")
|
||||
benchmark_kl_div_loss_with_backward()
|
||||
clean()
|
||||
|
||||
print("\nRunning detailed forward/backward benchmarks...")
|
||||
# benchmark_forward_backward_separately()
|
||||
# clean()
|
||||
|
||||
print("\nRunning memory usage benchmarks with backward passes...")
|
||||
benchmark_memory_usage_with_backward()
|
||||
clean()
|
||||
|
||||
|
||||
def clean():
|
||||
for _ in range(5):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -51,7 +51,6 @@ def grad_softmax_kernel(
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
@@ -63,26 +62,43 @@ def grad_softmax_kernel(
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Softmax over full vocab case
|
||||
# Process each teacher probability one at a time, computing all gradients for it
|
||||
for k in range(0, K):
|
||||
# Load token ID, teacher prob, and mask for this position
|
||||
# Load data for current position k
|
||||
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
|
||||
mask_val = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Precompute the self-influence term (multiplied by scale)
|
||||
self_term = teacher_prob * (1.0 - student_prob_k) * scale
|
||||
|
||||
# Calculate gradient contributions for all positions j
|
||||
for j in range(0, K):
|
||||
other_token_id = tl.load(token_ids_base + j * stride_t_k)
|
||||
token_id_j = tl.load(token_ids_base + j * stride_t_k)
|
||||
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
|
||||
mask_j = tl.load(mask_base + j * stride_m_k)
|
||||
|
||||
# Calculate the masking factor
|
||||
combined_mask = mask_val * mask_j
|
||||
is_diagonal = tl.where(j == k, 1.0, 0.0)
|
||||
self_grad = teacher_prob * (1.0 - student_prob_k)
|
||||
cross_grad = -teacher_prob * student_prob_j
|
||||
grad_val = (
|
||||
-(self_grad * is_diagonal + cross_grad * (1.0 - is_diagonal))
|
||||
* scale
|
||||
|
||||
# Determine if this is a diagonal or off-diagonal term
|
||||
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
|
||||
|
||||
# Compute the gradient contribution
|
||||
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
|
||||
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
|
||||
grad_contribution = (
|
||||
-(
|
||||
self_term * is_k_equals_j
|
||||
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
|
||||
)
|
||||
* combined_mask
|
||||
)
|
||||
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
|
||||
|
||||
# Atomically update the gradient for this token
|
||||
tl.atomic_add(
|
||||
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -315,7 +331,6 @@ class TopKKLDivergence(torch.autograd.Function):
|
||||
student_probs,
|
||||
) = ctx.saved_tensors
|
||||
kd_temperature = ctx.kd_temperature
|
||||
top_k_before_softmax = ctx.top_k_before_softmax
|
||||
num_items_in_batch = ctx.num_items_in_batch
|
||||
|
||||
batch_size, seq_len, vocab_size = student_logits.shape
|
||||
@@ -346,72 +361,36 @@ class TopKKLDivergence(torch.autograd.Function):
|
||||
teacher_probs = torch.exp(target_logprobs)
|
||||
|
||||
# Depending on which mode was used in forward, we use different gradient calculation
|
||||
if top_k_before_softmax:
|
||||
# Case 1: Softmax over top-k tokens
|
||||
grid = (batch_size * seq_len,)
|
||||
grad_topk_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
student_logits.stride(0),
|
||||
student_logits.stride(1),
|
||||
student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
else:
|
||||
# Case 2: Softmax over full vocab
|
||||
grid = (batch_size * seq_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
# FIXME: top_k_before_softmax not correctly yet?
|
||||
grid = (batch_size * seq_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Return gradients for student_logits and None for other inputs
|
||||
return grad_student_logits, None, None, None, None, None, None
|
||||
|
||||
@@ -18,6 +18,7 @@ KD trainer
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
|
||||
|
||||
@@ -85,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss_triton(
|
||||
loss_fn = (
|
||||
topk_kd_loss
|
||||
if self.args.kd_top_k_before_softmax
|
||||
else topk_kd_loss_triton
|
||||
)
|
||||
loss_kd = loss_fn(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
|
||||
Reference in New Issue
Block a user