diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 42488e643..0cba7cf94 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -202,7 +202,7 @@ class AxolotlTrainingMixins: ) kd_top_k_before_softmax: Optional[bool] = field( - default=None, + default=False, metadata={ "help": "Whether to apply top_k_before_softmax to the logits when using KD" }, diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index eb067cd04..43a90216e 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -15,12 +15,15 @@ """ Chat template prompt strategy loader with KD support """ +import logging from typing import Any, Dict import torch from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader +LOG = logging.getLogger(__name__) + class ChatTemplateStrategyWithKD(ChatTemplateStrategy): """ @@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - # for causal models, if we start the range at 1, then we don't need to shift in the trainer - # otherwise, we need to shift in the trainer - shift = 0 - for _ in range(shift, input_padding_len): + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) @@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # # Convert from log to probability teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum if self.kd_temperature != self.gen_temperature: # Exponentiate by factor (T1 / T2) exponent = self.gen_temperature / self.kd_temperature @@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) - if shift == 1: - # since we started at index 1 for causal, we need one more padding token - target_logprobs.append([-float("inf")] * top_k) - target_token_ids.append(list(range(top_k))) - target_mask.append([0] * top_k) - # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs sample["target_token_ids"] = target_token_ids