remove references to Triton KD for now

Wing Lian
2024-12-30 10:40:05 -05:00
parent cdb167e7f7
commit 7c4ae15942
2 changed files with 0 additions and 303 deletions


@@ -40,60 +40,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
if columns_to_add:
self._signature_columns += columns_to_add
# def compute_loss_w_triton(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# target_logprobs = inputs.pop("target_logprobs")
# target_token_ids = inputs.pop("target_token_ids")
# target_mask = inputs.pop("target_mask")
#
# if self.model_accepts_loss_kwargs:
# loss_kwargs = {}
# if num_items_in_batch is not None:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch
# inputs = {**inputs, **loss_kwargs}
# outputs = model(**inputs)
#
# student_logits = outputs["logits"]
# # Slice or gather student logits to match teacher seq len
# # e.g.:
# teacher_seq_len = target_token_ids.shape[1]
# student_logits_for_kd = student_logits[
# :, :teacher_seq_len, :
# ] # [B, seq_len, vocab_size]
#
# # GATHER top-K from student
# student_logits_topk = torch.gather(
# student_logits_for_kd,
# dim=-1,
# index=target_token_ids, # same shape [B, seq_len, K]
# )
#
# # Now call the Triton-based KD loss
# kd_sum = kd_loss_triton(
# student_logits_topk,
# target_logprobs, # teacher logprobs [B, seq_len, K]
# target_mask, # mask [B, seq_len, K]
# )
#
# # Normalize however you want
# if num_items_in_batch is not None:
# loss_kd = kd_sum / num_items_in_batch
# else:
# # or do e.g. average over valid tokens
# # quick example:
# total_valid = target_mask.sum()
# loss_kd = kd_sum / (total_valid + 1e-8)
#
# # optionally combine with CE loss
# if self.args.kd_ce_alpha > 0:
# kd_alpha = self.args.kd_alpha
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
# else:
# loss = loss_kd
#
# return (loss, outputs) if return_outputs else loss
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):


@@ -1,249 +0,0 @@
"""
Triton kernel for optimized KL divergence loss
"""
import torch
import triton
import triton.language as tl
# --------------------------------------------------------
# Triton Kernel for forward pass
# --------------------------------------------------------
# We'll assume:
# - a 1D launch grid over the B * seq_len positions
# - each program handles a block of BLOCK_SIZE positions and, for each position,
#   loops over the K teacher top-K entries.
# - For large K, you might want a more 2D approach to keep good occupancy.
#
# Pseudocode steps inside kernel:
# 1) compute index for [batch, seq_position]
# 2) read top-K token IDs from teacher_token_ids
# 3) gather student_logits_topk
# 4) compute logsumexp for those K logits
# 5) compute student_logprobs_topk
# 6) read teacher_logprobs
# 7) compute teacher_probs = exp(teacher_logprobs)
# 8) compute partial KL = sum(teacher_probs * (teacher_logprobs - student_logprobs_topk))
# 9) store partial KL in a buffer
#
# Later, we'll do a reduction on the partial KL across all positions.
#
# NOTE: This is a reference skeleton. You must adapt indexing carefully.
#
@triton.jit
def kd_forward_kernel(
# student_logits after gather: [B, seq_len, K] flattened to 1D in row-major
student_logits_ptr: tl.tensor,
# teacher_logprobs: [B, seq_len, K] flattened
teacher_logprobs_ptr: tl.tensor,
# mask: [B, seq_len, K] flattened (bool or 0/1)
mask_ptr: tl.tensor,
# partial_kd: [B*seq_len] flattened buffer to store partial sums
partial_kd_ptr: tl.tensor,
B: tl.int32, # pylint: disable=invalid-name
seq_len: tl.int32,
K: tl.int32, # pylint: disable=invalid-name
BLOCK_SIZE: tl.constexpr, # pylint: disable=invalid-name
):
"""
For each position in [0..B*seq_len), we:
- gather the K student logits
- compute logsumexp
- compute the KL sum = sum_{k} t_prob_k * ( t_log_k - s_logprob_k )
- store that partial sum into partial_kd_ptr[offset].
"""
# 1) Identify which [B*seq_len] index this block handles
pid = tl.program_id(0)
# 2) Vector of [0..BLOCK_SIZE) local offsets
offsets = tl.arange(0, BLOCK_SIZE)
# 3) Global indices = pid * BLOCK_SIZE + offsets
idx = pid * BLOCK_SIZE + offsets
# 4) Mask to ensure we don't read out-of-bounds
total_positions = B * seq_len
mask_pos = idx < total_positions
# 5) Convert a 1D `idx` => (b_idx, s_idx)
# b_idx is the batch number, s_idx is the sequence position
b_idx = idx // seq_len
s_idx = idx % seq_len
# We'll accumulate the KL for each index in a register array
kl_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
# -------------------------------------------------------------------------
# First pass: find max logits over K to implement logsumexp
# -------------------------------------------------------------------------
max_val = tl.full([BLOCK_SIZE], -1e30, dtype=tl.float32)
# Python-level loops are allowed in Triton as long as the
# operations inside are Triton ops, not torch or Python math.
for k in range(K):
# pointer offset in the flattened [B, seq_len, K] = b_idx*(seq_len*K) + s_idx*K + k
offset_k = b_idx * (seq_len * K) + s_idx * K + k
# load student logits, masked out-of-bounds with a large negative
# so they don't affect the max
student_val = tl.load(student_logits_ptr + offset_k, mask=mask_pos, other=-1e30)
# update running max
max_val = tl.where(student_val > max_val, student_val, max_val)
# -------------------------------------------------------------------------
# Second pass: sum of exp(...) to complete logsumexp
# -------------------------------------------------------------------------
exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k
student_val = tl.load(student_logits_ptr + offset_k, mask=mask_pos, other=-1e30)
# exponent
exponent = tl.exp(student_val - max_val)
exp_sum += exponent
# final logsumexp
logsumexp_val = max_val + tl.log(exp_sum)
# -------------------------------------------------------------------------
# Third pass: compute partial KL per position
# KL = sum_{k in valid} p^T_k * (teacher_logprobs_k - student_logprobs_k)
#
# - teacher_logprobs_k => t_log
# - teacher_prob_k = exp(t_log)
# - student_logprobs_k = s_val - logsumexp_val
# -------------------------------------------------------------------------
for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k
# teacher logprobs
t_log = tl.load(teacher_logprobs_ptr + offset_k, mask=mask_pos, other=-1e30)
# teacher prob
t_prob = tl.exp(t_log)
# student logit
s_val = tl.load(student_logits_ptr + offset_k, mask=mask_pos, other=-1e30)
# student logprob
s_logprob = s_val - logsumexp_val
# local KL
kl_val = t_prob * (t_log - s_logprob)
# also read mask to disable invalid tokens if mask is not purely sequence-based
valid_k = tl.load(mask_ptr + offset_k, mask=mask_pos, other=0)
# if mask is bool => use 'valid_k != 0', if it's 0/1 => same
is_valid = valid_k > 0
# zero out if either this index is out-of-bounds or mask is invalid
kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
# accumulate
kl_acc += kl_val
# -------------------------------------------------------------------------
# Store the partial KL in partial_kd_ptr for each element in idx.
# Later in Python, you can do partial_kd.sum() to get the total KL.
# -------------------------------------------------------------------------
tl.store(partial_kd_ptr + idx, kl_acc, mask=mask_pos)
def kd_forward_pass_triton(
student_logits, # [B, seq_len, K] (already gathered)
teacher_logprobs, # [B, seq_len, K]
mask, # [B, seq_len, K] bool or 0/1
BLOCK_SIZE=1024, # pylint: disable=invalid-name
):
"""
Returns total KL (float). We do the sum on the Python side.
NOTE: No normalization is done here.
You might divide by `num_items_in_batch` or the number of valid tokens afterward;
see the usage sketch after this function.
"""
B, seq_len, K = student_logits.shape # pylint: disable=invalid-name
# Flatten
student_logits_flat = student_logits.reshape(-1)
teacher_logprobs_flat = teacher_logprobs.reshape(-1)
mask_flat = mask.reshape(-1)
total_positions = B * seq_len
# We'll store partial KL sums for each of the B*seq_len positions
partial_kd = torch.empty(
total_positions, dtype=student_logits.dtype, device=student_logits.device
)
# Grid config
grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
kd_forward_kernel[grid](
student_logits_flat,
teacher_logprobs_flat,
mask_flat,
partial_kd,
B,
seq_len,
K,
BLOCK_SIZE=BLOCK_SIZE,
)
# Sum on CPU or GPU
kd_sum = partial_kd.sum()
return kd_sum
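# Hypothetical usage of the wrapper above (shapes and tensors are made up for
# illustration; normalization by valid tokens follows the docstring's note):
#
#   student_topk = torch.randn(2, 128, 8, device="cuda")           # [B, seq_len, K]
#   teacher_logprobs = torch.randn(2, 128, 8, device="cuda").log_softmax(-1)
#   mask = torch.ones(2, 128, 8, device="cuda")
#   kd_sum = kd_forward_pass_triton(student_topk, teacher_logprobs, mask)
#   loss_kd = kd_sum / mask.sum().clamp_min(1)                     # average over valid tokens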
class _KLDivergenceTritonFn(torch.autograd.Function):
@staticmethod
def forward(ctx, student_logits, teacher_logprobs, mask):
"""
student_logits: (B, seq_len, K)
teacher_logprobs: (B, seq_len, K)
mask: (B, seq_len, K)
"""
kd_sum = kd_forward_pass_triton(student_logits, teacher_logprobs, mask)
kd_loss = kd_sum # Not normalized here. You can do that externally.
# Save for backward
ctx.save_for_backward(student_logits, teacher_logprobs, mask)
return kd_loss
@staticmethod
def backward(ctx, grad_output):
# We'll do naive PyTorch re-computation for gradient wrt student_logits
student_logits, teacher_logprobs, mask = ctx.saved_tensors
# grad_output is dLoss/dOut => a scalar
# Let's compute dLoss/dStudentLogits with the same formula as the original
# PyTorch code (a small sanity-check sketch follows this class)
with torch.enable_grad():
stl = student_logits.clone().detach().requires_grad_(True)
t_log = teacher_logprobs
# mask might be bool or 0/1
# compute logsumexp
lse = torch.logsumexp(stl, dim=-1, keepdim=True)
s_logprob = stl - lse
t_prob = t_log.exp()
# forward KL = sum_{k} p^T_k ( t_log_k - s_logprob_k )
kl_val = t_prob * (t_log - s_logprob)
# mask out
kl_val = kl_val * mask # zero out invalid
kd_loss = kl_val.sum()
# now compute dLoss/d stl
grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]
return grad_stl, None, None
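# Illustrative sanity check (an added sketch, not part of the original file): on
# small random tensors the naive backward above should match plain PyTorch autograd
# applied to the eager reference, e.g.:
#
#   stl = torch.randn(2, 16, 8, device="cuda", requires_grad=True)
#   t_log = torch.randn(2, 16, 8, device="cuda").log_softmax(-1)
#   msk = torch.ones(2, 16, 8, device="cuda")
#   _KLDivergenceTritonFn.apply(stl, t_log, msk).backward()
#   # stl.grad can then be compared against autograd on kd_forward_reference(...)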
def kd_loss_triton(
student_logits, # [B, seq_len, K]: expected to already be gathered at the teacher's top-K token ids
teacher_logprobs,
mask,
num_items_in_batch=None, # pylint: disable=unused-argument
):
"""
Wrapper that calls our Triton-based forward+backward for KD.
For production, you likely want to do the gather (teacher top-K) outside
or inside a separate kernel. This function expects that you've *already*
called gather on student_logits -> shape [B, seq_len, K] (see the sketch below).
"""
return _KLDivergenceTritonFn.apply(
student_logits,
teacher_logprobs,
mask, # num_items_in_batch is not forwarded; normalization is left to the caller
)
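# Minimal end-to-end sketch (hypothetical shapes), mirroring the commented-out
# trainer code above: gather the student logits at the teacher's top-K token ids,
# then hand the gathered logits to kd_loss_triton.
#
#   teacher_seq_len = target_token_ids.shape[1]
#   student_logits_topk = torch.gather(
#       student_logits[:, :teacher_seq_len, :],   # [B, seq_len, vocab_size]
#       dim=-1,
#       index=target_token_ids,                   # [B, seq_len, K]
#   )
#   kd_sum = kd_loss_triton(student_logits_topk, target_logprobs, target_mask)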