This commit is contained in:
Wing Lian
2025-05-19 15:25:15 -07:00
parent 7909bfb076
commit 0e46367e01
2 changed files with 10 additions and 11 deletions

View File

@@ -202,7 +202,7 @@ class AxolotlTrainingMixins:
) )
kd_top_k_before_softmax: Optional[bool] = field( kd_top_k_before_softmax: Optional[bool] = field(
default=None, default=False,
metadata={ metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD" "help": "Whether to apply top_k_before_softmax to the logits when using KD"
}, },

View File

@@ -15,12 +15,15 @@
""" """
Chat template prompt strategy loader with KD support Chat template prompt strategy loader with KD support
""" """
import logging
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
LOG = logging.getLogger(__name__)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy): class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
""" """
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# for causal models, if we start the range at 1, then we don't need to shift in the trainer # we shift for causal models in the trainer, so start the range from 0
# otherwise, we need to shift in the trainer for _ in range(0, input_padding_len):
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# #
# Convert from log to probability # Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp() teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature: if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2) # Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature exponent = self.gen_temperature / self.kd_temperature
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids sample["target_token_ids"] = target_token_ids