chore: lint
@@ -7,7 +7,6 @@ from typing import Optional
import torch

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.kd.kernels.kd import kd_loss_triton


def kd_loss_function(
@@ -93,59 +92,59 @@ class AxolotlKDTrainer(AxolotlTrainer):
        if columns_to_add:
            self._signature_columns += columns_to_add

    def compute_loss_w_triton(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")
        target_mask = inputs.pop("target_mask")

        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}
        outputs = model(**inputs)

        student_logits = outputs["logits"]
        # Slice or gather student logits to match teacher seq len
        # e.g.:
        teacher_seq_len = target_token_ids.shape[1]
        student_logits_for_kd = student_logits[
            :, :teacher_seq_len, :
        ]  # [B, seq_len, vocab_size]

        # GATHER top-K from student
        student_logits_topk = torch.gather(
            student_logits_for_kd,
            dim=-1,
            index=target_token_ids,  # same shape [B, seq_len, K]
        )

        # Now call the Triton-based KD loss
        kd_sum = kd_loss_triton(
            student_logits_topk,
            target_logprobs,  # teacher logprobs [B, seq_len, K]
            target_mask,  # mask [B, seq_len, K]
        )

        # Normalize however you want
        if num_items_in_batch is not None:
            loss_kd = kd_sum / num_items_in_batch
        else:
            # or do e.g. average over valid tokens
            # quick example:
            total_valid = target_mask.sum()
            loss_kd = kd_sum / (total_valid + 1e-8)

        # optionally combine with CE loss
        if self.args.kd_ce_alpha > 0:
            kd_alpha = self.args.kd_alpha
            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
        else:
            loss = loss_kd

        return (loss, outputs) if return_outputs else loss

    # def compute_loss_w_triton(
    #     self, model, inputs, return_outputs=False, num_items_in_batch=None
    # ):
    #     target_logprobs = inputs.pop("target_logprobs")
    #     target_token_ids = inputs.pop("target_token_ids")
    #     target_mask = inputs.pop("target_mask")
    #
    #     if self.model_accepts_loss_kwargs:
    #         loss_kwargs = {}
    #         if num_items_in_batch is not None:
    #             loss_kwargs["num_items_in_batch"] = num_items_in_batch
    #         inputs = {**inputs, **loss_kwargs}
    #     outputs = model(**inputs)
    #
    #     student_logits = outputs["logits"]
    #     # Slice or gather student logits to match teacher seq len
    #     # e.g.:
    #     teacher_seq_len = target_token_ids.shape[1]
    #     student_logits_for_kd = student_logits[
    #         :, :teacher_seq_len, :
    #     ]  # [B, seq_len, vocab_size]
    #
    #     # GATHER top-K from student
    #     student_logits_topk = torch.gather(
    #         student_logits_for_kd,
    #         dim=-1,
    #         index=target_token_ids,  # same shape [B, seq_len, K]
    #     )
    #
    #     # Now call the Triton-based KD loss
    #     kd_sum = kd_loss_triton(
    #         student_logits_topk,
    #         target_logprobs,  # teacher logprobs [B, seq_len, K]
    #         target_mask,  # mask [B, seq_len, K]
    #     )
    #
    #     # Normalize however you want
    #     if num_items_in_batch is not None:
    #         loss_kd = kd_sum / num_items_in_batch
    #     else:
    #         # or do e.g. average over valid tokens
    #         # quick example:
    #         total_valid = target_mask.sum()
    #         loss_kd = kd_sum / (total_valid + 1e-8)
    #
    #     # optionally combine with CE loss
    #     if self.args.kd_ce_alpha > 0:
    #         kd_alpha = self.args.kd_alpha
    #         loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
    #     else:
    #         loss = loss_kd
    #
    #     return (loss, outputs) if return_outputs else loss

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
@@ -156,10 +155,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
        Subclass and override for custom behavior.
        """

        # return self.compute_loss_w_triton(
        #     model, inputs, return_outputs, num_items_in_batch
        # )

        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")
        target_mask = inputs.pop("target_mask")
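The hunks above assume the batch already carries target_logprobs, target_token_ids and target_mask of shape [B, seq_len, K]. That preprocessing is not part of this commit; the sketch below shows how such top-K teacher targets might be built from raw teacher logits, with the helper name and the value of k being illustrative assumptions rather than the repo's API.

import torch
import torch.nn.functional as F


def build_topk_teacher_targets(teacher_logits, attention_mask, k=64):
    # hypothetical helper, not part of this commit: reduce full teacher logits
    # [B, seq_len, vocab] to the [B, seq_len, K] tensors popped from `inputs` above
    teacher_logprobs = F.log_softmax(teacher_logits.float(), dim=-1)
    target_logprobs, target_token_ids = teacher_logprobs.topk(k, dim=-1)

    # broadcast the per-token attention mask [B, seq_len] to [B, seq_len, K]
    target_mask = attention_mask.bool().unsqueeze(-1).expand_as(target_token_ids)
    return target_logprobs, target_token_ids, target_mask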
@@ -1,3 +1,7 @@
"""
Triton kernel for optimized kl divergence loss
"""

import torch
import triton
import triton.language as tl
@@ -37,10 +41,10 @@ def kd_forward_kernel(
    mask_ptr: tl.tensor,
    # partial_kd: [B*seq_len] flattened buffer to store partial sums
    partial_kd_ptr: tl.tensor,
    B: tl.int32,
    B: tl.int32,  # pylint: disable=invalid-name
    seq_len: tl.int32,
    K: tl.int32,
    BLOCK_SIZE: tl.constexpr,
    K: tl.int32,  # pylint: disable=invalid-name
    BLOCK_SIZE: tl.constexpr,  # pylint: disable=invalid-name
):
    """
    For each position in [0..B*seq_len), we:
@@ -82,11 +86,7 @@ def kd_forward_kernel(
        # load student logits, masked out-of-bounds with a large negative
        # so they don't affect the max
        student_val = tl.where(
            mask_pos,
            tl.load(student_logits_ptr + offset_k),
            -1e30
        )
        student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
        # update running max
        max_val = tl.where(student_val > max_val, student_val, max_val)
@@ -96,11 +96,7 @@ def kd_forward_kernel(
    exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for k in range(K):
        offset_k = b_idx * (seq_len * K) + s_idx * K + k
        student_val = tl.where(
            mask_pos,
            tl.load(student_logits_ptr + offset_k),
            -1e30
        )
        student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
        # exponent
        exponent = tl.exp(student_val - max_val)
        exp_sum += exponent
@@ -119,20 +115,12 @@ def kd_forward_kernel(
    for k in range(K):
        offset_k = b_idx * (seq_len * K) + s_idx * K + k
        # teacher logprobs
        t_log = tl.where(
            mask_pos,
            tl.load(teacher_logprobs_ptr + offset_k),
            -1e30
        )
        t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30)
        # teacher prob
        t_prob = tl.exp(t_log)

        # student logit
        s_val = tl.where(
            mask_pos,
            tl.load(student_logits_ptr + offset_k),
            -1e30
        )
        s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
        # student logprob
        s_logprob = s_val - logsumexp_val
@@ -142,7 +130,7 @@ def kd_forward_kernel(
        # also read mask to disable invalid tokens if mask is not purely sequence-based
        valid_k = tl.load(mask_ptr + offset_k)
        # if mask is bool => use 'valid_k != 0', if it's 0/1 => same
        is_valid = (valid_k > 0)
        is_valid = valid_k > 0

        # zero out if either this index is out-of-bounds or mask is invalid
        kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
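For reference, the per-position arithmetic in kd_forward_kernel (a log-sum-exp over the K gathered student logits, followed by a masked forward KL against the teacher's top-K log-probabilities) can be written as a short eager-mode PyTorch sketch. This only documents the math the kernel appears to implement; it is not code from this commit.

import torch


def kd_forward_reference(student_logits, teacher_logprobs, mask):
    # student log-probs restricted to the K gathered entries
    student_logprobs = torch.log_softmax(student_logits.float(), dim=-1)

    # forward KL per entry: p_teacher * (log p_teacher - log q_student)
    teacher_probs = teacher_logprobs.float().exp()
    kl = teacher_probs * (teacher_logprobs.float() - student_logprobs)

    # zero out entries the mask marks invalid, then sum, matching the kernel's output
    kl = torch.where(mask.bool(), kl, torch.zeros_like(kl))
    return kl.sum()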
@@ -158,17 +146,17 @@ def kd_forward_kernel(


def kd_forward_pass_triton(
    student_logits,  # [B, seq_len, K] (already gathered)
    student_logits,  # [B, seq_len, K] (already gathered)
    teacher_logprobs,  # [B, seq_len, K]
    mask,  # [B, seq_len, K] bool or 0/1
    BLOCK_SIZE=1024,
    mask,  # [B, seq_len, K] bool or 0/1
    BLOCK_SIZE=1024,  # pylint: disable=invalid-name
):
    """
    Returns total KL (float). We do the sum on the Python side.
    NOTE: No normalization is done here.
    You might divide by `num_items_in_batch` or # valid tokens afterward.
    """
    B, seq_len, K = student_logits.shape
    B, seq_len, K = student_logits.shape  # pylint: disable=invalid-name
    # Flatten
    student_logits_flat = student_logits.reshape(-1)
    teacher_logprobs_flat = teacher_logprobs.reshape(-1)
@@ -188,14 +176,17 @@ def kd_forward_pass_triton(
        teacher_logprobs_flat,
        mask_flat,
        partial_kd,
        B, seq_len, K,
        BLOCK_SIZE=BLOCK_SIZE
        B,
        seq_len,
        K,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Sum on CPU or GPU
    kd_sum = partial_kd.sum()
    return kd_sum


class _KLDivergenceTritonFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, student_logits, teacher_logprobs, mask):
@@ -211,7 +202,6 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
        ctx.save_for_backward(student_logits, teacher_logprobs, mask)
        return kd_loss


    @staticmethod
    def backward(ctx, grad_output):
        # We'll do naive PyTorch re-computation for gradient wrt student_logits
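The body of backward falls outside this hunk. As an illustration only: with q = softmax(student_logits) over the K gathered entries and p = exp(teacher_logprobs), the gradient of the summed forward KL with respect to the student logits is q * p.sum(-1, keepdim=True) - p, so a naive re-computation along the lines of the comment above might look like the sketch below. This is an assumption drawn from that comment, not the committed code.

import torch


def kd_backward_reference(grad_output, student_logits, teacher_logprobs, mask):
    # all shapes assumed [B, seq_len, K], i.e. the tensors saved in forward()
    q = torch.softmax(student_logits.float(), dim=-1)
    p = teacher_logprobs.float().exp() * mask.float()  # drop invalid entries

    # d/ds of sum_k p_k * (log p_k - log q_k)  =  q * sum_k p_k  -  p
    grad_student = grad_output * (q * p.sum(dim=-1, keepdim=True) - p)
    # one gradient per forward() input: (student_logits, teacher_logprobs, mask)
    return grad_student, None, None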
@@ -244,7 +234,7 @@ def kd_loss_triton(
    student_logits,  # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
    teacher_logprobs,
    mask,
    num_items_in_batch=None,
    num_items_in_batch=None,  # pylint: disable=unused-argument
):
    """
    Wrapper that calls our Triton-based forward+backward for KD.
@@ -253,5 +243,7 @@ def kd_loss_triton(
    called gather on student_logits -> shape [B, seq_len, K].
    """
    return _KLDivergenceTritonFn.apply(
        student_logits, teacher_logprobs, mask,  # num_items_in_batch
        student_logits,
        teacher_logprobs,
        mask,  # num_items_in_batch
    )
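A small end-to-end usage sketch of the wrapper with dummy shapes follows; the CUDA device and the normalization choice are assumptions for illustration, since the Triton kernel is launched on GPU tensors and the wrapper returns an unnormalized sum.

import torch

B, seq_len, K = 2, 16, 8  # toy sizes
student_topk = torch.randn(B, seq_len, K, device="cuda", requires_grad=True)
teacher_logprobs = torch.randn(B, seq_len, K, device="cuda").log_softmax(dim=-1)
mask = torch.ones(B, seq_len, K, device="cuda", dtype=torch.bool)

kd_sum = kd_loss_triton(student_topk, teacher_logprobs, mask)
loss = kd_sum / mask.sum()  # e.g. normalize by the number of valid entries
loss.backward()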