kd fixes
This commit is contained in:
@@ -202,7 +202,7 @@ class AxolotlTrainingMixins:
|
|||||||
)
|
)
|
||||||
|
|
||||||
kd_top_k_before_softmax: Optional[bool] = field(
|
kd_top_k_before_softmax: Optional[bool] = field(
|
||||||
default=None,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -15,12 +15,15 @@
|
|||||||
"""
|
"""
|
||||||
Chat template prompt strategy loader with KD support
|
Chat template prompt strategy loader with KD support
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
|
|
||||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
# we shift for causal models in the trainer, so start the range from 0
|
||||||
# otherwise, we need to shift in the trainer
|
for _ in range(0, input_padding_len):
|
||||||
shift = 0
|
|
||||||
for _ in range(shift, input_padding_len):
|
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
#
|
#
|
||||||
# Convert from log to probability
|
# Convert from log to probability
|
||||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||||
|
# normalize probabilities to sum to 1 in case they aren't already
|
||||||
|
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||||
|
if teacher_probs_t1_sum > 1e-9:
|
||||||
|
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||||
if self.kd_temperature != self.gen_temperature:
|
if self.kd_temperature != self.gen_temperature:
|
||||||
# Exponentiate by factor (T1 / T2)
|
# Exponentiate by factor (T1 / T2)
|
||||||
exponent = self.gen_temperature / self.kd_temperature
|
exponent = self.gen_temperature / self.kd_temperature
|
||||||
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_logprobs.append(position_logprobs_scaled)
|
target_logprobs.append(position_logprobs_scaled)
|
||||||
target_token_ids.append(position_token_ids)
|
target_token_ids.append(position_token_ids)
|
||||||
|
|
||||||
if shift == 1:
|
|
||||||
# since we started at index 1 for causal, we need one more padding token
|
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
|
||||||
target_token_ids.append(list(range(top_k)))
|
|
||||||
target_mask.append([0] * top_k)
|
|
||||||
|
|
||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
sample["target_token_ids"] = target_token_ids
|
sample["target_token_ids"] = target_token_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user