rename vars for consistency
This commit is contained in:
@@ -70,11 +70,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||
|
||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||
target_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||
|
||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||
teacher_probs_valid = target_logprobs_valid.exp()
|
||||
teacher_probs_valid = teacher_logprobs_valid.exp()
|
||||
# Student probabilities P_student from log P_student
|
||||
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||
|
||||
@@ -88,12 +88,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||
if beta == 0.0: # Contribution from Forward KL
|
||||
fwd_kl_per_token = teacher_probs_valid * (
|
||||
target_logprobs_valid - student_logprobs_topk_valid
|
||||
teacher_logprobs_valid - student_logprobs_topk_valid
|
||||
)
|
||||
kd_loss = fwd_kl_per_token.sum()
|
||||
elif beta == 1.0: # Contribution from Reverse KL
|
||||
rev_kl_per_token = student_probs_topk_valid * (
|
||||
student_logprobs_topk_valid - target_logprobs_valid
|
||||
student_logprobs_topk_valid - teacher_logprobs_valid
|
||||
)
|
||||
kd_loss = rev_kl_per_token.sum()
|
||||
else:
|
||||
@@ -109,7 +109,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
log_target=True,
|
||||
)
|
||||
teacher_kl = F.kl_div(
|
||||
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
|
||||
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
|
||||
)
|
||||
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||
kd_loss = jsd_loss
|
||||
|
||||
Reference in New Issue
Block a user