chore: lint

This commit is contained in:
Wing Lian
2024-12-28 16:02:06 -05:00
parent 7b5a24b0d2
commit 3f97ec45fb
2 changed files with 78 additions and 91 deletions

View File

@@ -7,7 +7,6 @@ from typing import Optional
import torch import torch
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.kd.kernels.kd import kd_loss_triton
def kd_loss_function( def kd_loss_function(
@@ -93,59 +92,59 @@ class AxolotlKDTrainer(AxolotlTrainer):
if columns_to_add: if columns_to_add:
self._signature_columns += columns_to_add self._signature_columns += columns_to_add
def compute_loss_w_triton( # def compute_loss_w_triton(
self, model, inputs, return_outputs=False, num_items_in_batch=None # self, model, inputs, return_outputs=False, num_items_in_batch=None
): # ):
target_logprobs = inputs.pop("target_logprobs") # target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids") # target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask") # target_mask = inputs.pop("target_mask")
#
if self.model_accepts_loss_kwargs: # if self.model_accepts_loss_kwargs:
loss_kwargs = {} # loss_kwargs = {}
if num_items_in_batch is not None: # if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch # loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs} # inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs) # outputs = model(**inputs)
#
student_logits = outputs["logits"] # student_logits = outputs["logits"]
# Slice or gather student logits to match teacher seq len # # Slice or gather student logits to match teacher seq len
# e.g.: # # e.g.:
teacher_seq_len = target_token_ids.shape[1] # teacher_seq_len = target_token_ids.shape[1]
student_logits_for_kd = student_logits[ # student_logits_for_kd = student_logits[
:, :teacher_seq_len, : # :, :teacher_seq_len, :
] # [B, seq_len, vocab_size] # ] # [B, seq_len, vocab_size]
#
# GATHER top-K from student # # GATHER top-K from student
student_logits_topk = torch.gather( # student_logits_topk = torch.gather(
student_logits_for_kd, # student_logits_for_kd,
dim=-1, # dim=-1,
index=target_token_ids, # same shape [B, seq_len, K] # index=target_token_ids, # same shape [B, seq_len, K]
) # )
#
# Now call the Triton-based KD loss # # Now call the Triton-based KD loss
kd_sum = kd_loss_triton( # kd_sum = kd_loss_triton(
student_logits_topk, # student_logits_topk,
target_logprobs, # teacher logprobs [B, seq_len, K] # target_logprobs, # teacher logprobs [B, seq_len, K]
target_mask, # mask [B, seq_len, K] # target_mask, # mask [B, seq_len, K]
) # )
#
# Normalize however you want # # Normalize however you want
if num_items_in_batch is not None: # if num_items_in_batch is not None:
loss_kd = kd_sum / num_items_in_batch # loss_kd = kd_sum / num_items_in_batch
else: # else:
# or do e.g. average over valid tokens # # or do e.g. average over valid tokens
# quick example: # # quick example:
total_valid = target_mask.sum() # total_valid = target_mask.sum()
loss_kd = kd_sum / (total_valid + 1e-8) # loss_kd = kd_sum / (total_valid + 1e-8)
#
# optionally combine with CE loss # # optionally combine with CE loss
if self.args.kd_ce_alpha > 0: # if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha # kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd # loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else: # else:
loss = loss_kd # loss = loss_kd
#
return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
def compute_loss( def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None self, model, inputs, return_outputs=False, num_items_in_batch=None
@@ -156,10 +155,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior. Subclass and override for custom behavior.
""" """
# return self.compute_loss_w_triton(
# model, inputs, return_outputs, num_items_in_batch
# )
target_logprobs = inputs.pop("target_logprobs") target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids") target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask") target_mask = inputs.pop("target_mask")

View File

@@ -1,3 +1,7 @@
"""
Triton kernel for optimized KL divergence loss
"""
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -37,10 +41,10 @@ def kd_forward_kernel(
mask_ptr: tl.tensor, mask_ptr: tl.tensor,
# partial_kd: [B*seq_len] flattened buffer to store partial sums # partial_kd: [B*seq_len] flattened buffer to store partial sums
partial_kd_ptr: tl.tensor, partial_kd_ptr: tl.tensor,
B: tl.int32, B: tl.int32, # pylint: disable=invalid-name
seq_len: tl.int32, seq_len: tl.int32,
K: tl.int32, K: tl.int32, # pylint: disable=invalid-name
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, # pylint: disable=invalid-name
): ):
""" """
For each position in [0..B*seq_len), we: For each position in [0..B*seq_len), we:
@@ -82,11 +86,7 @@ def kd_forward_kernel(
# load student logits, masked out-of-bounds with a large negative # load student logits, masked out-of-bounds with a large negative
# so they don't affect the max # so they don't affect the max
student_val = tl.where( student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
# update running max # update running max
max_val = tl.where(student_val > max_val, student_val, max_val) max_val = tl.where(student_val > max_val, student_val, max_val)
@@ -96,11 +96,7 @@ def kd_forward_kernel(
exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for k in range(K): for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k offset_k = b_idx * (seq_len * K) + s_idx * K + k
student_val = tl.where( student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
# exponent # exponent
exponent = tl.exp(student_val - max_val) exponent = tl.exp(student_val - max_val)
exp_sum += exponent exp_sum += exponent
@@ -119,20 +115,12 @@ def kd_forward_kernel(
for k in range(K): for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k offset_k = b_idx * (seq_len * K) + s_idx * K + k
# teacher logprobs # teacher logprobs
t_log = tl.where( t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30)
mask_pos,
tl.load(teacher_logprobs_ptr + offset_k),
-1e30
)
# teacher prob # teacher prob
t_prob = tl.exp(t_log) t_prob = tl.exp(t_log)
# student logit # student logit
s_val = tl.where( s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
# student logprob # student logprob
s_logprob = s_val - logsumexp_val s_logprob = s_val - logsumexp_val
@@ -142,7 +130,7 @@ def kd_forward_kernel(
# also read mask to disable invalid tokens if mask is not purely sequence-based # also read mask to disable invalid tokens if mask is not purely sequence-based
valid_k = tl.load(mask_ptr + offset_k) valid_k = tl.load(mask_ptr + offset_k)
# if mask is bool => use 'valid_k != 0', if it's 0/1 => same # if mask is bool => use 'valid_k != 0', if it's 0/1 => same
is_valid = (valid_k > 0) is_valid = valid_k > 0
# zero out if either this index is out-of-bounds or mask is invalid # zero out if either this index is out-of-bounds or mask is invalid
kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0) kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
@@ -158,17 +146,17 @@ def kd_forward_kernel(
def kd_forward_pass_triton( def kd_forward_pass_triton(
student_logits, # [B, seq_len, K] (already gathered) student_logits, # [B, seq_len, K] (already gathered)
teacher_logprobs, # [B, seq_len, K] teacher_logprobs, # [B, seq_len, K]
mask, # [B, seq_len, K] bool or 0/1 mask, # [B, seq_len, K] bool or 0/1
BLOCK_SIZE=1024, BLOCK_SIZE=1024, # pylint: disable=invalid-name
): ):
""" """
Returns total KL (float). We do the sum on the Python side. Returns total KL (float). We do the sum on the Python side.
NOTE: No normalization is done here. NOTE: No normalization is done here.
You might divide by `num_items_in_batch` or # valid tokens afterward. You might divide by `num_items_in_batch` or # valid tokens afterward.
""" """
B, seq_len, K = student_logits.shape B, seq_len, K = student_logits.shape # pylint: disable=invalid-name
# Flatten # Flatten
student_logits_flat = student_logits.reshape(-1) student_logits_flat = student_logits.reshape(-1)
teacher_logprobs_flat = teacher_logprobs.reshape(-1) teacher_logprobs_flat = teacher_logprobs.reshape(-1)
@@ -188,14 +176,17 @@ def kd_forward_pass_triton(
teacher_logprobs_flat, teacher_logprobs_flat,
mask_flat, mask_flat,
partial_kd, partial_kd,
B, seq_len, K, B,
BLOCK_SIZE=BLOCK_SIZE seq_len,
K,
BLOCK_SIZE=BLOCK_SIZE,
) )
# Sum on CPU or GPU # Sum on CPU or GPU
kd_sum = partial_kd.sum() kd_sum = partial_kd.sum()
return kd_sum return kd_sum
class _KLDivergenceTritonFn(torch.autograd.Function): class _KLDivergenceTritonFn(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, student_logits, teacher_logprobs, mask): def forward(ctx, student_logits, teacher_logprobs, mask):
@@ -211,7 +202,6 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
ctx.save_for_backward(student_logits, teacher_logprobs, mask) ctx.save_for_backward(student_logits, teacher_logprobs, mask)
return kd_loss return kd_loss
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# We'll do naive PyTorch re-computation for gradient wrt student_logits # We'll do naive PyTorch re-computation for gradient wrt student_logits
@@ -244,7 +234,7 @@ def kd_loss_triton(
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
teacher_logprobs, teacher_logprobs,
mask, mask,
num_items_in_batch=None, num_items_in_batch=None, # pylint: disable=unused-argument
): ):
""" """
Wrapper that calls our Triton-based forward+backward for KD. Wrapper that calls our Triton-based forward+backward for KD.
@@ -253,5 +243,7 @@ def kd_loss_triton(
called gather on student_logits -> shape [B, seq_len, K]. called gather on student_logits -> shape [B, seq_len, K].
""" """
return _KLDivergenceTritonFn.apply( return _KLDivergenceTritonFn.apply(
student_logits, teacher_logprobs, mask, # num_items_in_batch student_logits,
teacher_logprobs,
mask, # num_items_in_batch
) )