chore: lint
This commit is contained in:
@@ -7,7 +7,6 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
from axolotl.integrations.kd.kernels.kd import kd_loss_triton
|
|
||||||
|
|
||||||
|
|
||||||
def kd_loss_function(
|
def kd_loss_function(
|
||||||
@@ -93,59 +92,59 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
if columns_to_add:
|
if columns_to_add:
|
||||||
self._signature_columns += columns_to_add
|
self._signature_columns += columns_to_add
|
||||||
|
|
||||||
def compute_loss_w_triton(
|
# def compute_loss_w_triton(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
# self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
# ):
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
# target_logprobs = inputs.pop("target_logprobs")
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
# target_token_ids = inputs.pop("target_token_ids")
|
||||||
target_mask = inputs.pop("target_mask")
|
# target_mask = inputs.pop("target_mask")
|
||||||
|
#
|
||||||
if self.model_accepts_loss_kwargs:
|
# if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
# loss_kwargs = {}
|
||||||
if num_items_in_batch is not None:
|
# if num_items_in_batch is not None:
|
||||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
# loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||||
inputs = {**inputs, **loss_kwargs}
|
# inputs = {**inputs, **loss_kwargs}
|
||||||
outputs = model(**inputs)
|
# outputs = model(**inputs)
|
||||||
|
#
|
||||||
student_logits = outputs["logits"]
|
# student_logits = outputs["logits"]
|
||||||
# Slice or gather student logits to match teacher seq len
|
# # Slice or gather student logits to match teacher seq len
|
||||||
# e.g.:
|
# # e.g.:
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
# teacher_seq_len = target_token_ids.shape[1]
|
||||||
student_logits_for_kd = student_logits[
|
# student_logits_for_kd = student_logits[
|
||||||
:, :teacher_seq_len, :
|
# :, :teacher_seq_len, :
|
||||||
] # [B, seq_len, vocab_size]
|
# ] # [B, seq_len, vocab_size]
|
||||||
|
#
|
||||||
# GATHER top-K from student
|
# # GATHER top-K from student
|
||||||
student_logits_topk = torch.gather(
|
# student_logits_topk = torch.gather(
|
||||||
student_logits_for_kd,
|
# student_logits_for_kd,
|
||||||
dim=-1,
|
# dim=-1,
|
||||||
index=target_token_ids, # same shape [B, seq_len, K]
|
# index=target_token_ids, # same shape [B, seq_len, K]
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
# Now call the Triton-based KD loss
|
# # Now call the Triton-based KD loss
|
||||||
kd_sum = kd_loss_triton(
|
# kd_sum = kd_loss_triton(
|
||||||
student_logits_topk,
|
# student_logits_topk,
|
||||||
target_logprobs, # teacher logprobs [B, seq_len, K]
|
# target_logprobs, # teacher logprobs [B, seq_len, K]
|
||||||
target_mask, # mask [B, seq_len, K]
|
# target_mask, # mask [B, seq_len, K]
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
# Normalize however you want
|
# # Normalize however you want
|
||||||
if num_items_in_batch is not None:
|
# if num_items_in_batch is not None:
|
||||||
loss_kd = kd_sum / num_items_in_batch
|
# loss_kd = kd_sum / num_items_in_batch
|
||||||
else:
|
# else:
|
||||||
# or do e.g. average over valid tokens
|
# # or do e.g. average over valid tokens
|
||||||
# quick example:
|
# # quick example:
|
||||||
total_valid = target_mask.sum()
|
# total_valid = target_mask.sum()
|
||||||
loss_kd = kd_sum / (total_valid + 1e-8)
|
# loss_kd = kd_sum / (total_valid + 1e-8)
|
||||||
|
#
|
||||||
# optionally combine with CE loss
|
# # optionally combine with CE loss
|
||||||
if self.args.kd_ce_alpha > 0:
|
# if self.args.kd_ce_alpha > 0:
|
||||||
kd_alpha = self.args.kd_alpha
|
# kd_alpha = self.args.kd_alpha
|
||||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||||
else:
|
# else:
|
||||||
loss = loss_kd
|
# loss = loss_kd
|
||||||
|
#
|
||||||
return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
@@ -156,10 +155,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
Subclass and override for custom behavior.
|
Subclass and override for custom behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# return self.compute_loss_w_triton(
|
|
||||||
# model, inputs, return_outputs, num_items_in_batch
|
|
||||||
# )
|
|
||||||
|
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
target_logprobs = inputs.pop("target_logprobs")
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
target_token_ids = inputs.pop("target_token_ids")
|
||||||
target_mask = inputs.pop("target_mask")
|
target_mask = inputs.pop("target_mask")
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Triton kernel for optimized kl divergence loss
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@@ -37,10 +41,10 @@ def kd_forward_kernel(
|
|||||||
mask_ptr: tl.tensor,
|
mask_ptr: tl.tensor,
|
||||||
# partial_kd: [B*seq_len] flattened buffer to store partial sums
|
# partial_kd: [B*seq_len] flattened buffer to store partial sums
|
||||||
partial_kd_ptr: tl.tensor,
|
partial_kd_ptr: tl.tensor,
|
||||||
B: tl.int32,
|
B: tl.int32, # pylint: disable=invalid-name
|
||||||
seq_len: tl.int32,
|
seq_len: tl.int32,
|
||||||
K: tl.int32,
|
K: tl.int32, # pylint: disable=invalid-name
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr, # pylint: disable=invalid-name
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
For each position in [0..B*seq_len), we:
|
For each position in [0..B*seq_len), we:
|
||||||
@@ -82,11 +86,7 @@ def kd_forward_kernel(
|
|||||||
|
|
||||||
# load student logits, masked out-of-bounds with a large negative
|
# load student logits, masked out-of-bounds with a large negative
|
||||||
# so they don't affect the max
|
# so they don't affect the max
|
||||||
student_val = tl.where(
|
student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
|
||||||
mask_pos,
|
|
||||||
tl.load(student_logits_ptr + offset_k),
|
|
||||||
-1e30
|
|
||||||
)
|
|
||||||
# update running max
|
# update running max
|
||||||
max_val = tl.where(student_val > max_val, student_val, max_val)
|
max_val = tl.where(student_val > max_val, student_val, max_val)
|
||||||
|
|
||||||
@@ -96,11 +96,7 @@ def kd_forward_kernel(
|
|||||||
exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
offset_k = b_idx * (seq_len * K) + s_idx * K + k
|
offset_k = b_idx * (seq_len * K) + s_idx * K + k
|
||||||
student_val = tl.where(
|
student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
|
||||||
mask_pos,
|
|
||||||
tl.load(student_logits_ptr + offset_k),
|
|
||||||
-1e30
|
|
||||||
)
|
|
||||||
# exponent
|
# exponent
|
||||||
exponent = tl.exp(student_val - max_val)
|
exponent = tl.exp(student_val - max_val)
|
||||||
exp_sum += exponent
|
exp_sum += exponent
|
||||||
@@ -119,20 +115,12 @@ def kd_forward_kernel(
|
|||||||
for k in range(K):
|
for k in range(K):
|
||||||
offset_k = b_idx * (seq_len * K) + s_idx * K + k
|
offset_k = b_idx * (seq_len * K) + s_idx * K + k
|
||||||
# teacher logprobs
|
# teacher logprobs
|
||||||
t_log = tl.where(
|
t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30)
|
||||||
mask_pos,
|
|
||||||
tl.load(teacher_logprobs_ptr + offset_k),
|
|
||||||
-1e30
|
|
||||||
)
|
|
||||||
# teacher prob
|
# teacher prob
|
||||||
t_prob = tl.exp(t_log)
|
t_prob = tl.exp(t_log)
|
||||||
|
|
||||||
# student logit
|
# student logit
|
||||||
s_val = tl.where(
|
s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
|
||||||
mask_pos,
|
|
||||||
tl.load(student_logits_ptr + offset_k),
|
|
||||||
-1e30
|
|
||||||
)
|
|
||||||
# student logprob
|
# student logprob
|
||||||
s_logprob = s_val - logsumexp_val
|
s_logprob = s_val - logsumexp_val
|
||||||
|
|
||||||
@@ -142,7 +130,7 @@ def kd_forward_kernel(
|
|||||||
# also read mask to disable invalid tokens if mask is not purely sequence-based
|
# also read mask to disable invalid tokens if mask is not purely sequence-based
|
||||||
valid_k = tl.load(mask_ptr + offset_k)
|
valid_k = tl.load(mask_ptr + offset_k)
|
||||||
# if mask is bool => use 'valid_k != 0', if it's 0/1 => same
|
# if mask is bool => use 'valid_k != 0', if it's 0/1 => same
|
||||||
is_valid = (valid_k > 0)
|
is_valid = valid_k > 0
|
||||||
|
|
||||||
# zero out if either this index is out-of-bounds or mask is invalid
|
# zero out if either this index is out-of-bounds or mask is invalid
|
||||||
kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
|
kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
|
||||||
@@ -158,17 +146,17 @@ def kd_forward_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def kd_forward_pass_triton(
|
def kd_forward_pass_triton(
|
||||||
student_logits, # [B, seq_len, K] (already gathered)
|
student_logits, # [B, seq_len, K] (already gathered)
|
||||||
teacher_logprobs, # [B, seq_len, K]
|
teacher_logprobs, # [B, seq_len, K]
|
||||||
mask, # [B, seq_len, K] bool or 0/1
|
mask, # [B, seq_len, K] bool or 0/1
|
||||||
BLOCK_SIZE=1024,
|
BLOCK_SIZE=1024, # pylint: disable=invalid-name
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Returns total KL (float). We do the sum on the Python side.
|
Returns total KL (float). We do the sum on the Python side.
|
||||||
NOTE: No normalization is done here.
|
NOTE: No normalization is done here.
|
||||||
You might divide by `num_items_in_batch` or # valid tokens afterward.
|
You might divide by `num_items_in_batch` or # valid tokens afterward.
|
||||||
"""
|
"""
|
||||||
B, seq_len, K = student_logits.shape
|
B, seq_len, K = student_logits.shape # pylint: disable=invalid-name
|
||||||
# Flatten
|
# Flatten
|
||||||
student_logits_flat = student_logits.reshape(-1)
|
student_logits_flat = student_logits.reshape(-1)
|
||||||
teacher_logprobs_flat = teacher_logprobs.reshape(-1)
|
teacher_logprobs_flat = teacher_logprobs.reshape(-1)
|
||||||
@@ -188,14 +176,17 @@ def kd_forward_pass_triton(
|
|||||||
teacher_logprobs_flat,
|
teacher_logprobs_flat,
|
||||||
mask_flat,
|
mask_flat,
|
||||||
partial_kd,
|
partial_kd,
|
||||||
B, seq_len, K,
|
B,
|
||||||
BLOCK_SIZE=BLOCK_SIZE
|
seq_len,
|
||||||
|
K,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sum on CPU or GPU
|
# Sum on CPU or GPU
|
||||||
kd_sum = partial_kd.sum()
|
kd_sum = partial_kd.sum()
|
||||||
return kd_sum
|
return kd_sum
|
||||||
|
|
||||||
|
|
||||||
class _KLDivergenceTritonFn(torch.autograd.Function):
|
class _KLDivergenceTritonFn(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, student_logits, teacher_logprobs, mask):
|
def forward(ctx, student_logits, teacher_logprobs, mask):
|
||||||
@@ -211,7 +202,6 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
ctx.save_for_backward(student_logits, teacher_logprobs, mask)
|
ctx.save_for_backward(student_logits, teacher_logprobs, mask)
|
||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
# We'll do naive PyTorch re-computation for gradient wrt student_logits
|
# We'll do naive PyTorch re-computation for gradient wrt student_logits
|
||||||
@@ -244,7 +234,7 @@ def kd_loss_triton(
|
|||||||
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
|
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
|
||||||
teacher_logprobs,
|
teacher_logprobs,
|
||||||
mask,
|
mask,
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Wrapper that calls our Triton-based forward+backward for KD.
|
Wrapper that calls our Triton-based forward+backward for KD.
|
||||||
@@ -253,5 +243,7 @@ def kd_loss_triton(
|
|||||||
called gather on student_logits -> shape [B, seq_len, K].
|
called gather on student_logits -> shape [B, seq_len, K].
|
||||||
"""
|
"""
|
||||||
return _KLDivergenceTritonFn.apply(
|
return _KLDivergenceTritonFn.apply(
|
||||||
student_logits, teacher_logprobs, mask, # num_items_in_batch
|
student_logits,
|
||||||
|
teacher_logprobs,
|
||||||
|
mask, # num_items_in_batch
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user