diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index 787740bcf..f30904d5a 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -129,8 +129,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
 
         loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
 
-        # return loss, (soft_loss, ce_loss, student_logits_chunk)  # Aux outputs
-        return loss, (soft_loss, ce_loss)  # Aux outputs
+        return soft_loss, ce_loss
 
     @classmethod
     def forward(
@@ -159,7 +158,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             if student_lm_head_bias is not None
             else None
         )
-        loss_acc = torch.zeros(
+        kd_loss_acc = torch.zeros(
+            (), device=student_input.device, dtype=student_input.dtype
+        )
+        ce_loss_acc = torch.zeros(
             (), device=student_input.device, dtype=student_input.dtype
         )
 
@@ -203,7 +205,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             if student_lm_head_bias is not None:
                 (
                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
-                    (chunk_loss, _aux_outputs),
+                    (chunk_kd_loss, chunk_ce_loss),
                 ) = torch.func.grad_and_value(
                     loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
                 )(
@@ -220,7 +222,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
                 argnums_for_grad = (0, 1)  # Differentiate wrt input_chunk, weight
                 (
                     (chunk_grad_input, chunk_grad_weight),  # No grad for bias
-                    (chunk_loss, _aux_outputs),
+                    (chunk_kd_loss, chunk_ce_loss),
                 ) = torch.func.grad_and_value(
                     loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
                 )(
@@ -234,7 +236,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             )
             grad_weight_acc.add_(chunk_grad_weight)
-            loss_acc.add_(chunk_loss)
+            kd_loss_acc.add_(chunk_kd_loss)
+            ce_loss_acc.add_(chunk_ce_loss)
+
             return chunk_grad_input
 
         if compiled:
@@ -288,9 +292,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
 
         ctx.bias_was_none = student_lm_head_bias is None
        ctx.orig_dims = (B, N, D, K)
-        num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
-        ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
-        final_loss = loss_acc  # / ctx.num_valid_tokens_scalar
+        # since the input is packed, there is a single batch, so the batchmean reduction of the kl-div is simply the accumulated sum
+        # we still need to scale the kd_loss by the squared temperature
+        kd_loss_acc = kd_loss_acc * (temperature ** 2)
+        final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
 
         return final_loss
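
For reference, a minimal PyTorch sketch of the loss combination the last hunk performs, written outside the fused/chunked kernel. The tensor names (student_logits, teacher_logits, labels) and the example weight/temperature values are illustrative assumptions; only weight_soft_loss, weight_hard_loss, and temperature correspond to parameters that actually appear in the patch.

# Illustrative sketch only -- not the Liger fused kernel. It mirrors the final
# hunk: sum-accumulated KL (soft) and CE (hard) terms, a temperature**2 factor
# on the KD term, then a weighted sum of the two.
import torch
import torch.nn.functional as F

temperature = 2.0                               # assumed example value
weight_soft_loss, weight_hard_loss = 0.9, 0.1   # assumed example weights

N, V = 8, 32                                    # packed tokens x vocab size (toy sizes)
student_logits = torch.randn(N, V, requires_grad=True)
teacher_logits = torch.randn(N, V)
labels = torch.randint(0, V, (N,))

# Soft (KD) term: KL divergence between temperature-softened distributions.
# reduction="sum" stands in for the accumulated per-chunk sums in the diff
# (the packed input is treated as a single batch); temperature**2 restores the
# gradient scale lost by softening the logits.
kd_loss = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
    reduction="sum",
) * (temperature ** 2)

# Hard (CE) term against the ground-truth labels.
ce_loss = F.cross_entropy(student_logits, labels, reduction="sum")

final_loss = weight_soft_loss * kd_loss + weight_hard_loss * ce_loss
final_loss.backward()

The key design point the hunk encodes is that the soft/hard weighting is applied once to the accumulated totals rather than per chunk, which is why the per-chunk loss function now returns the raw (soft_loss, ce_loss) pair instead of a pre-weighted sum.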