rename vars for consistency
This commit is contained in:
@@ -70,11 +70,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||||
|
|
||||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||||
target_logprobs_valid = target_logprobs_chunk[valid_mask]
|
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||||
|
|
||||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||||
teacher_probs_valid = target_logprobs_valid.exp()
|
teacher_probs_valid = teacher_logprobs_valid.exp()
|
||||||
# Student probabilities P_student from log P_student
|
# Student probabilities P_student from log P_student
|
||||||
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||||
|
|
||||||
@@ -88,12 +88,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||||
if beta == 0.0: # Contribution from Forward KL
|
if beta == 0.0: # Contribution from Forward KL
|
||||||
fwd_kl_per_token = teacher_probs_valid * (
|
fwd_kl_per_token = teacher_probs_valid * (
|
||||||
target_logprobs_valid - student_logprobs_topk_valid
|
teacher_logprobs_valid - student_logprobs_topk_valid
|
||||||
)
|
)
|
||||||
kd_loss = fwd_kl_per_token.sum()
|
kd_loss = fwd_kl_per_token.sum()
|
||||||
elif beta == 1.0: # Contribution from Reverse KL
|
elif beta == 1.0: # Contribution from Reverse KL
|
||||||
rev_kl_per_token = student_probs_topk_valid * (
|
rev_kl_per_token = student_probs_topk_valid * (
|
||||||
student_logprobs_topk_valid - target_logprobs_valid
|
student_logprobs_topk_valid - teacher_logprobs_valid
|
||||||
)
|
)
|
||||||
kd_loss = rev_kl_per_token.sum()
|
kd_loss = rev_kl_per_token.sum()
|
||||||
else:
|
else:
|
||||||
@@ -109,7 +109,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
log_target=True,
|
log_target=True,
|
||||||
)
|
)
|
||||||
teacher_kl = F.kl_div(
|
teacher_kl = F.kl_div(
|
||||||
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
|
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
|
||||||
)
|
)
|
||||||
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||||
kd_loss = jsd_loss
|
kd_loss = jsd_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user