rename vars for consistency

This commit is contained in:
Wing Lian
2025-06-04 11:58:20 -07:00
parent 2302b14a84
commit cfcd69df0d

View File

@@ -70,11 +70,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
target_logprobs_valid = target_logprobs_chunk[valid_mask]
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = target_logprobs_valid.exp()
teacher_probs_valid = teacher_logprobs_valid.exp()
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
@@ -88,12 +88,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_topk_valid
teacher_logprobs_valid - student_logprobs_topk_valid
)
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - target_logprobs_valid
student_logprobs_topk_valid - teacher_logprobs_valid
)
kd_loss = rev_kl_per_token.sum()
else:
@@ -109,7 +109,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss