kd fixes
This commit is contained in:
@@ -202,7 +202,7 @@ class AxolotlTrainingMixins:
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=None,
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
|
||||
@@ -15,12 +15,15 @@
|
||||
"""
|
||||
Chat template prompt strategy loader with KD support
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
"""
|
||||
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
# fill each of the padding_len padded positions with -inf for all top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||
# otherwise, we need to shift in the trainer
|
||||
shift = 0
|
||||
for _ in range(shift, input_padding_len):
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
if shift == 1:
|
||||
# since we started at index 1 for causal, we need one more padding token
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
|
||||
Reference in New Issue
Block a user