temp scale kd loss at end
This commit is contained in:
@@ -129,8 +129,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
|
|
||||||
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
|
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
|
||||||
|
|
||||||
# return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs
|
return soft_loss, ce_loss
|
||||||
return loss, (soft_loss, ce_loss) # Aux outputs
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def forward(
|
def forward(
|
||||||
@@ -159,7 +158,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
if student_lm_head_bias is not None
|
if student_lm_head_bias is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
loss_acc = torch.zeros(
|
kd_loss_acc = torch.zeros(
|
||||||
|
(), device=student_input.device, dtype=student_input.dtype
|
||||||
|
)
|
||||||
|
ce_loss_acc = torch.zeros(
|
||||||
(), device=student_input.device, dtype=student_input.dtype
|
(), device=student_input.device, dtype=student_input.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -203,7 +205,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
if student_lm_head_bias is not None:
|
if student_lm_head_bias is not None:
|
||||||
(
|
(
|
||||||
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
||||||
(chunk_loss, _aux_outputs),
|
(chunk_kd_loss, chunk_ce_loss),
|
||||||
) = torch.func.grad_and_value(
|
) = torch.func.grad_and_value(
|
||||||
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
|
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
|
||||||
)(
|
)(
|
||||||
@@ -220,7 +222,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
|
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
|
||||||
(
|
(
|
||||||
(chunk_grad_input, chunk_grad_weight), # No grad for bias
|
(chunk_grad_input, chunk_grad_weight), # No grad for bias
|
||||||
(chunk_loss, _aux_outputs),
|
(chunk_kd_loss, chunk_ce_loss),
|
||||||
) = torch.func.grad_and_value(
|
) = torch.func.grad_and_value(
|
||||||
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
|
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
|
||||||
)(
|
)(
|
||||||
@@ -234,7 +236,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
grad_weight_acc.add_(chunk_grad_weight)
|
grad_weight_acc.add_(chunk_grad_weight)
|
||||||
loss_acc.add_(chunk_loss)
|
kd_loss_acc.add_(chunk_kd_loss)
|
||||||
|
ce_loss_acc.add_(chunk_ce_loss)
|
||||||
|
|
||||||
return chunk_grad_input
|
return chunk_grad_input
|
||||||
|
|
||||||
if compiled:
|
if compiled:
|
||||||
@@ -288,9 +292,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
ctx.bias_was_none = student_lm_head_bias is None
|
ctx.bias_was_none = student_lm_head_bias is None
|
||||||
ctx.orig_dims = (B, N, D, K)
|
ctx.orig_dims = (B, N, D, K)
|
||||||
|
|
||||||
num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
|
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulatedsum
|
||||||
ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
|
# we still need to scale the kd_loss by the temp
|
||||||
final_loss = loss_acc # / ctx.num_valid_tokens_scalar
|
kd_loss_acc = kd_loss_acc * (temperature ** 2)
|
||||||
|
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
||||||
|
|
||||||
return final_loss
|
return final_loss
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user