diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py index 74c97897b..6356643c2 100644 --- a/src/axolotl/integrations/kd/kernels/liger.py +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -70,11 +70,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] - target_logprobs_valid = target_logprobs_chunk[valid_mask] + teacher_logprobs_valid = target_logprobs_chunk[valid_mask] # Teacher probabilities P(y|x_teacher) from logprobs # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) - teacher_probs_valid = target_logprobs_valid.exp() + teacher_probs_valid = teacher_logprobs_valid.exp() # Student probabilities P_student from log P_student student_probs_topk_valid = student_logprobs_topk_valid.exp() @@ -88,12 +88,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). if beta == 0.0: # Contribution from Forward KL fwd_kl_per_token = teacher_probs_valid * ( - target_logprobs_valid - student_logprobs_topk_valid + teacher_logprobs_valid - student_logprobs_topk_valid ) kd_loss = fwd_kl_per_token.sum() elif beta == 1.0: # Contribution from Reverse KL rev_kl_per_token = student_probs_topk_valid * ( - student_logprobs_topk_valid - target_logprobs_valid + student_logprobs_topk_valid - teacher_logprobs_valid ) kd_loss = rev_kl_per_token.sum() else: @@ -109,7 +109,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): log_target=True, ) teacher_kl = F.kl_div( - log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True + log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True ) jsd_loss = beta * teacher_kl + (1 - beta) * student_kl kd_loss = jsd_loss