temp scale kd loss at end

Wing Lian
2025-05-26 23:52:29 -04:00
parent 90c7228ff9
commit 24b96b1c4f


@@ -129,8 +129,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
# return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs
- return loss, (soft_loss, ce_loss) # Aux outputs
+ return soft_loss, ce_loss
@classmethod
def forward(
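For orientation, this first hunk changes what the per-chunk loss function hands back: instead of a pre-weighted combined loss, it now returns the soft (KD) loss as the primary output and the CE loss as the aux output, deferring weighting and temperature scaling to the end. A minimal sketch of that contract, with illustrative names and a plain (non-fused, non-top-k) KL/CE computation standing in for the real kernel:

    import torch
    import torch.nn.functional as F

    def chunk_loss_fn(student_logits_chunk, teacher_log_probs_chunk, labels_chunk,
                      temperature=1.0, ignore_index=-100):
        # Soft (KD) component: KL between temperature-softened student/teacher distributions.
        soft_loss = F.kl_div(
            F.log_softmax(student_logits_chunk / temperature, dim=-1),
            teacher_log_probs_chunk,
            reduction="sum",
            log_target=True,
        )
        # Hard component: cross-entropy against the true labels.
        ce_loss = F.cross_entropy(
            student_logits_chunk, labels_chunk,
            ignore_index=ignore_index, reduction="sum",
        )
        # New contract: return the components separately; the first value is what
        # torch.func.grad_and_value differentiates, the second rides along as aux.
        return soft_loss, ce_loss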
@@ -159,7 +158,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
if student_lm_head_bias is not None
else None
)
- loss_acc = torch.zeros(
+ kd_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
+ ce_loss_acc = torch.zeros(
+ (), device=student_input.device, dtype=student_input.dtype
+ )
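The single loss_acc is split into one accumulator per loss component. A small sketch of the pattern (illustrative shapes and values, not the kernel's actual state):

    import torch

    student_input = torch.randn(8, 4096)

    # One zero-dim accumulator per loss component, created on the same device and
    # dtype as the student activations so chunk losses can be added in place.
    kd_loss_acc = torch.zeros((), device=student_input.device, dtype=student_input.dtype)
    ce_loss_acc = torch.zeros((), device=student_input.device, dtype=student_input.dtype)

    # Each processed chunk later contributes its two components:
    for chunk_kd_loss, chunk_ce_loss in [(torch.tensor(0.7), torch.tensor(1.2)),
                                         (torch.tensor(0.5), torch.tensor(0.9))]:
        kd_loss_acc.add_(chunk_kd_loss)
        ce_loss_acc.add_(chunk_ce_loss)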
@@ -203,7 +205,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
if student_lm_head_bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
- (chunk_loss, _aux_outputs),
+ (chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
)(
@@ -220,7 +222,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
(
(chunk_grad_input, chunk_grad_weight), # No grad for bias
- (chunk_loss, _aux_outputs),
+ (chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
)(
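Both branches (with and without the lm_head bias) unpack the value side of torch.func.grad_and_value the same way. With has_aux=True the transform returns (grads, (primary_output, aux)), which is why the tuple now destructures as (chunk_kd_loss, chunk_ce_loss): the soft loss is the differentiated output and the CE loss rides along as aux. A self-contained toy example of that calling convention (the loss functions here are placeholders):

    import torch

    def toy_loss_fn(x, w):
        # First return value is differentiated; the second is an aux output.
        soft_loss = (x @ w).pow(2).mean()
        ce_loss = x.abs().mean().detach()
        return soft_loss, ce_loss

    x = torch.randn(4, 8)
    w = torch.randn(8, 3)

    # grads is a tuple matching argnums; value is (primary_output, aux) because has_aux=True.
    (grad_x, grad_w), (kd_loss, ce_loss) = torch.func.grad_and_value(
        toy_loss_fn, argnums=(0, 1), has_aux=True
    )(x, w)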
@@ -234,7 +236,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
)
grad_weight_acc.add_(chunk_grad_weight)
- loss_acc.add_(chunk_loss)
+ kd_loss_acc.add_(chunk_kd_loss)
+ ce_loss_acc.add_(chunk_ce_loss)
return chunk_grad_input
if compiled:
@@ -288,9 +292,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
- final_loss = loss_acc # / ctx.num_valid_tokens_scalar
# since this is packed, there is simply a single batch, so the batchmean reduction of kl-div is simply the accumulated sum
+ # we still need to scale the kd_loss by the squared temperature
+ kd_loss_acc = kd_loss_acc * (temperature ** 2)
+ final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
return final_loss
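For reference, the final combination now matches the standard temperature-scaled distillation objective: softening the logits by T shrinks the soft-target gradients by roughly 1/T^2, so the KD term is multiplied by T^2 to keep it on the same scale as the hard CE term (Hinton et al., 2015). A minimal, unchunked sketch of the end-to-end objective (an illustrative function, not the fused Liger kernel):

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels,
                          temperature=2.0, weight_soft_loss=0.5, weight_hard_loss=0.5,
                          ignore_index=-100):
        # Soft (KD) term: KL divergence between temperature-softened distributions,
        # scaled back up by temperature ** 2 at the end, as in this commit.
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        )
        kd_loss = kd_loss * (temperature ** 2)

        # Hard term: cross-entropy against the true labels.
        ce_loss = F.cross_entropy(student_logits, labels, ignore_index=ignore_index)

        return weight_soft_loss * kd_loss + weight_hard_loss * ce_loss

    # Example usage with a packed batch of 6 tokens over a 32-token vocabulary.
    student_logits = torch.randn(6, 32)
    teacher_logits = torch.randn(6, 32)
    labels = torch.randint(0, 32, (6,))
    loss = distillation_loss(student_logits, teacher_logits, labels)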