From 42d4732aafe7a33c99d00d66c96f1fb9e86c24f4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 28 Jan 2025 19:40:35 -0500
Subject: [PATCH] kd loss needs to be calculated in full precision

---
 src/axolotl/integrations/kd/topk_logprob/forward_kl.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 6a1c80411..a61011ad5 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -63,6 +63,8 @@ def loss(
     A KD loss function that is TorchScript-friendly.
     """
 
+    target_logprobs = target_logprobs.float()
+
     # Determine the teacher sequence length
     # target_token_ids shape: [B, teacher_seq_len, K]
     # student_logits shape: [B, student_seq_len, vocab_size]
@@ -78,6 +80,8 @@ def loss(
         student_logits_for_kd, dim=-1, index=target_token_ids
     )  # [B, teacher_seq_len, K]
 
+    student_logits_topk = student_logits_topk.float()
+
     # Apply KD temperature to student’s logits
     if kd_temperature != 1.0:
         student_logits_topk = student_logits_topk / kd_temperature
@@ -130,6 +134,8 @@ def topk_kd_loss_with_zscore(
     from "Logit Standardization in Knowledge Distillation".
     """
 
+    teacher_topk_logprobs = teacher_topk_logprobs.float()
+
     B, teacher_seq_len, K = teacher_topk_logprobs.shape  # pylint: disable=invalid-name
     # 1) Gather the student's top-k logits to match teacher
     student_logits_for_kd = student_logits[
@@ -139,6 +145,8 @@ def topk_kd_loss_with_zscore(
         student_logits_for_kd, dim=-1, index=teacher_topk_ids
     )  # [B, seq_len, K]
 
+    student_topk_logits = student_topk_logits.float()
+
     # 2) If you want to keep the "classical" T scaling, apply it first
     if kd_temperature != 1.0:
         student_topk_logits = student_topk_logits / kd_temperature