Always re-normalize teacher distribution

This commit is contained in:
Wing Lian
2025-01-29 08:36:40 -05:00
parent 42d4732aaf
commit c434951dd6

View File

@@ -126,22 +126,22 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_token_ids.append(token_id) position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float position_logprobs, dtype=torch.float
) )
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k). # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
# #
# Convert from log to probability # Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp() teacher_probs_t1 = position_logprobs_tensor.exp()
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2) # Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize # Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True dim=0, keepdim=True