From c434951dd6da9c04b5334eba3e63b98c064c026c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Jan 2025 08:36:40 -0500 Subject: [PATCH] Always re-normalize teacher distribution --- src/axolotl/integrations/kd/chat_template.py | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 671f8e5c8..5a7e4f40d 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -126,28 +126,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): position_token_ids.append(token_id) # Convert to a tensor for easier manipulation - # Convert to tensor position_logprobs_tensor = torch.tensor( position_logprobs, dtype=torch.float ) + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). + # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() if self.kd_temperature != self.gen_temperature: - # - # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). - # Next, re-scale to T2 = self.kd_temperature via exponent-based trick - # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z - # - # Convert from log to probability - teacher_probs_t1 = position_logprobs_tensor.exp() # Exponentiate by factor (T1 / T2) exponent = self.gen_temperature / self.kd_temperature teacher_probs_t2 = teacher_probs_t1**exponent - # Re-normalize - teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( - dim=0, keepdim=True - ) - # Convert back to log - position_logprobs_tensor = torch.log(teacher_probs_t2) + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor position_logprobs_scaled = position_logprobs_tensor.tolist()