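KD loss needs to be calculated in full precision (cast the teacher log-probs and the gathered student logits with `.float()` before computing the loss)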
This commit is contained in:
@@ -63,6 +63,8 @@ def loss(
|
|||||||
A KD loss function that is TorchScript-friendly.
|
A KD loss function that is TorchScript-friendly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
target_logprobs = target_logprobs.float()
|
||||||
|
|
||||||
# Determine the teacher sequence length
|
# Determine the teacher sequence length
|
||||||
# target_token_ids shape: [B, teacher_seq_len, K]
|
# target_token_ids shape: [B, teacher_seq_len, K]
|
||||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||||
@@ -78,6 +80,8 @@ def loss(
|
|||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
) # [B, teacher_seq_len, K]
|
) # [B, teacher_seq_len, K]
|
||||||
|
|
||||||
|
student_logits_topk = student_logits_topk.float()
|
||||||
|
|
||||||
# Apply KD temperature to student’s logits
|
# Apply KD temperature to student’s logits
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
student_logits_topk = student_logits_topk / kd_temperature
|
student_logits_topk = student_logits_topk / kd_temperature
|
||||||
@@ -130,6 +134,8 @@ def topk_kd_loss_with_zscore(
|
|||||||
from "Logit Standardization in Knowledge Distillation".
|
from "Logit Standardization in Knowledge Distillation".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
teacher_topk_logprobs = teacher_topk_logprobs.float()
|
||||||
|
|
||||||
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
|
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
|
||||||
# 1) Gather the student's top-k logits to match teacher
|
# 1) Gather the student's top-k logits to match teacher
|
||||||
student_logits_for_kd = student_logits[
|
student_logits_for_kd = student_logits[
|
||||||
@@ -139,6 +145,8 @@ def topk_kd_loss_with_zscore(
|
|||||||
student_logits_for_kd, dim=-1, index=teacher_topk_ids
|
student_logits_for_kd, dim=-1, index=teacher_topk_ids
|
||||||
) # [B, seq_len, K]
|
) # [B, seq_len, K]
|
||||||
|
|
||||||
|
student_topk_logits = student_topk_logits.float()
|
||||||
|
|
||||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
student_topk_logits = student_topk_logits / kd_temperature
|
student_topk_logits = student_topk_logits / kd_temperature
|
||||||
|
|||||||
Reference in New Issue
Block a user