kd loss needs to be calculated in full precision

This commit is contained in:
Wing Lian
2025-01-28 19:40:35 -05:00
parent 2c9dfbed2e
commit 42d4732aaf

View File

@@ -63,6 +63,8 @@ def loss(
A KD loss function that is TorchScript-friendly.
"""
target_logprobs = target_logprobs.float()
# Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size]
@@ -78,6 +80,8 @@ def loss(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
@@ -130,6 +134,8 @@ def topk_kd_loss_with_zscore(
from "Logit Standardization in Knowledge Distillation".
"""
teacher_topk_logprobs = teacher_topk_logprobs.float()
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
@@ -139,6 +145,8 @@ def topk_kd_loss_with_zscore(
student_logits_for_kd, dim=-1, index=teacher_topk_ids
) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float()
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature