kd loss needs to be calculated in full precision
This commit is contained in:
@@ -63,6 +63,8 @@ def loss(
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
# Determine the teacher sequence length
|
||||
# target_token_ids shape: [B, teacher_seq_len, K]
|
||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||
@@ -78,6 +80,8 @@ def loss(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
student_logits_topk = student_logits_topk.float()
|
||||
|
||||
# Apply KD temperature to student’s logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
@@ -130,6 +134,8 @@ def topk_kd_loss_with_zscore(
|
||||
from "Logit Standardization in Knowledge Distillation".
|
||||
"""
|
||||
|
||||
teacher_topk_logprobs = teacher_topk_logprobs.float()
|
||||
|
||||
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
|
||||
# 1) Gather the student's top-k logits to match teacher
|
||||
student_logits_for_kd = student_logits[
|
||||
@@ -139,6 +145,8 @@ def topk_kd_loss_with_zscore(
|
||||
student_logits_for_kd, dim=-1, index=teacher_topk_ids
|
||||
) # [B, seq_len, K]
|
||||
|
||||
student_topk_logits = student_topk_logits.float()
|
||||
|
||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||
if kd_temperature != 1.0:
|
||||
student_topk_logits = student_topk_logits / kd_temperature
|
||||
|
||||
Reference in New Issue
Block a user