temp scale kd loss at end
@@ -129,8 +129,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
-        loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
         # return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs
-        return loss, (soft_loss, ce_loss)  # Aux outputs
+        return soft_loss, ce_loss
 
     @classmethod
     def forward(
@@ -159,7 +158,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             if student_lm_head_bias is not None
             else None
         )
-        loss_acc = torch.zeros(
+        kd_loss_acc = torch.zeros(
             (), device=student_input.device, dtype=student_input.dtype
         )
+        ce_loss_acc = torch.zeros(
+            (), device=student_input.device, dtype=student_input.dtype
+        )
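Both accumulators are 0-dim tensors allocated once with `torch.zeros(())` on the input's device and dtype, so the in-place `add_` calls in the chunk loop never allocate fresh tensors or move data across devices. A small sketch of the pattern under those assumptions (variable names are illustrative):

```python
import torch

# Scalar accumulator pattern: a 0-dim tensor matching the data's device/dtype,
# grown in place one chunk at a time.
x = torch.randn(8, 4)
acc = torch.zeros((), device=x.device, dtype=x.dtype)
for chunk in x.split(2, dim=0):
    acc.add_(chunk.sum())  # in-place: no per-chunk allocation
assert torch.allclose(acc, x.sum())
```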
@@ -203,7 +205,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         if student_lm_head_bias is not None:
             (
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
-                (chunk_loss, _aux_outputs),
+                (chunk_kd_loss, chunk_ce_loss),
             ) = torch.func.grad_and_value(
                 loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
             )(
@@ -220,7 +222,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             argnums_for_grad = (0, 1)  # Differentiate wrt input_chunk, weight
             (
                 (chunk_grad_input, chunk_grad_weight),  # No grad for bias
-                (chunk_loss, _aux_outputs),
+                (chunk_kd_loss, chunk_ce_loss),
             ) = torch.func.grad_and_value(
                 loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
             )(
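The two branches above differ only in `argnums`: the bias is included among the differentiated arguments when present and skipped otherwise. A hedged sketch of that selection (the function and names here are made up for illustration, not this class's API):

```python
import torch

# Non-differentiated arguments (such as a None bias) can still be passed
# through; only positions listed in argnums must be tensors.
def loss_fn(inp, weight, bias):
    out = inp @ weight.t()
    if bias is not None:
        out = out + bias
    loss = out.pow(2).mean()
    return loss, loss.detach()  # (value, aux) pair expected with has_aux=True

inp, weight, bias = torch.randn(4, 3), torch.randn(5, 3), None
argnums = (0, 1, 2) if bias is not None else (0, 1)  # no grad for absent bias
grads, (loss, aux) = torch.func.grad_and_value(
    loss_fn, argnums=argnums, has_aux=True
)(inp, weight, bias)
```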
@@ -234,7 +236,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             )
 
             grad_weight_acc.add_(chunk_grad_weight)
-            loss_acc.add_(chunk_loss)
+            kd_loss_acc.add_(chunk_kd_loss)
+            ce_loss_acc.add_(chunk_ce_loss)
 
             return chunk_grad_input
 
         if compiled:
@@ -288,9 +292,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         ctx.bias_was_none = student_lm_head_bias is None
         ctx.orig_dims = (B, N, D, K)
 
-        num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
-        ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
-        final_loss = loss_acc  # / ctx.num_valid_tokens_scalar
+        # since the inputs are packed there is a single batch, so the batchmean reduction of kl-div is simply the accumulated sum
+        # we still need to scale the kd loss by the temperature
+        kd_loss_acc = kd_loss_acc * (temperature**2)
+        final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
 
         return final_loss
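The `temperature**2` factor is the standard knowledge-distillation correction from Hinton et al.: dividing logits by T softens the distributions but also shrinks the soft-target gradients by roughly 1/T^2, so the KD term is rescaled by T^2 to stay commensurate with the hard cross-entropy term. A hedged, standalone sketch of the combined loss under the classic formulation (names are illustrative, and the reductions may differ from this class's packed, top-k variant):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels,
                 temperature=2.0, weight_soft_loss=0.5, weight_hard_loss=0.5):
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * (temperature**2)  # the scaling this commit applies at the end
    ce = F.cross_entropy(student_logits, labels)
    return weight_soft_loss * kd + weight_hard_loss * ce
```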