Always re-normalize teacher distribution
This commit is contained in:
@@ -126,28 +126,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
position_token_ids.append(token_id)
|
position_token_ids.append(token_id)
|
||||||
|
|
||||||
# Convert to a tensor for easier manipulation
|
# Convert to a tensor for easier manipulation
|
||||||
# Convert to tensor
|
|
||||||
position_logprobs_tensor = torch.tensor(
|
position_logprobs_tensor = torch.tensor(
|
||||||
position_logprobs, dtype=torch.float
|
position_logprobs, dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||||
|
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||||
|
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||||
|
#
|
||||||
|
# Convert from log to probability
|
||||||
|
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||||
if self.kd_temperature != self.gen_temperature:
|
if self.kd_temperature != self.gen_temperature:
|
||||||
#
|
|
||||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
|
||||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
|
||||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
|
||||||
#
|
|
||||||
# Convert from log to probability
|
|
||||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
|
||||||
# Exponentiate by factor (T1 / T2)
|
# Exponentiate by factor (T1 / T2)
|
||||||
exponent = self.gen_temperature / self.kd_temperature
|
exponent = self.gen_temperature / self.kd_temperature
|
||||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||||
# Re-normalize
|
else:
|
||||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
teacher_probs_t2 = teacher_probs_t1
|
||||||
dim=0, keepdim=True
|
# Re-normalize
|
||||||
)
|
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||||
# Convert back to log
|
dim=0, keepdim=True
|
||||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
)
|
||||||
|
# Convert back to log
|
||||||
|
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||||
|
|
||||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||||
|
|||||||
Reference in New Issue
Block a user