better rescaling for temperatures
@@ -680,6 +680,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
+        if self.cfg.kd_alpha is not None:
+            training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
 
         training_args_cls = (
             AxolotlTrainingArguments
@@ -16,6 +16,7 @@ def kd_loss_function(
     target_logprobs,
     target_mask,
     num_items_in_batch: Optional[int] = None,
+    kd_temperature: float = 1.0,
 ):
     # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
 
@@ -28,9 +29,15 @@ def kd_loss_function(
     ]  # [B, teacher_seq_len, vocab_size]
 
     # Gather student logits for teacher's top-K tokens
+    # shape -> [B, teacher_seq_len, K]
    student_logits_topk = torch.gather(
         student_logits_for_kd, dim=-1, index=target_token_ids
-    )
+    )  # [B, teacher_seq_len, K]
+
+    # Apply KD temperature to student's logits:
+    # z_s(T) = z_s / T
+    if kd_temperature != 1.0:
+        student_logits_topk = student_logits_topk / kd_temperature
 
     # Convert student top-k logits to logprobs
     student_logprobs_topk = student_logits_topk - torch.logsumexp(
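In standard KD notation this is the softened softmax over the teacher's top-K tokens. Writing z_k for the gathered student logit of top-K token k, and assuming the logsumexp normalizes over those same top-K logits (its argument falls outside this hunk), the block above computes the first expression below; the per-token loss in the next hunk then consumes it as a forward KL:

```latex
\log p^{(T)}_{s,k} = \frac{z_k}{T} - \log\sum_{j=1}^{K} \exp\!\left(\frac{z_j}{T}\right),
\qquad
L_{\mathrm{KD}} = \sum_{k=1}^{K} p^{(T)}_{t,k}\left(\log p^{(T)}_{t,k} - \log p^{(T)}_{s,k}\right)
```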
@@ -53,6 +60,10 @@ def kd_loss_function(
     kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
     kd_loss = kd_loss_per_token.sum()
 
+    # 9) Multiply by T^2 (classical KD scaling)
+    if kd_temperature != 1.0:
+        kd_loss = kd_loss * (kd_temperature**2)
+
     # Normalize by number of items or mean over valid tokens
     if num_items_in_batch is not None:
         # If you know how many items should be considered in the batch
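The T^2 factor is the classical correction from Hinton et al. (2015): the gradient of the soft-target loss with respect to the student logits picks up one 1/T from the chain rule and, in the high-temperature limit, roughly another 1/T from the flattened distributions, so multiplying the loss by T^2 keeps the KD gradient magnitude roughly independent of the temperature and commensurate with the CE term. With v_k the teacher logit, K classes, and zero-mean logits:

```latex
\frac{\partial L_{\mathrm{KD}}}{\partial z_k}
  = \frac{1}{T}\left(p^{(T)}_{s,k} - p^{(T)}_{t,k}\right)
  \;\approx\; \frac{1}{K\,T^{2}}\left(z_k - v_k\right)
  \quad (\text{high-}T\ \text{limit})
```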
@@ -129,7 +140,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
 
         # optionally combine with CE loss
         if self.args.kd_ce_alpha > 0:
-            loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
+            kd_alpha = self.args.kd_alpha
+            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
         else:
             loss = loss_kd
 
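Previously the KD term was added unweighted; with the new `kd_alpha` the combination is a plain weighted sum. A toy illustration, with invented numbers:

```python
# Toy sanity check of the combined objective (values are made up):
kd_ce_alpha = 0.2  # weight on the hard-label CE term
kd_alpha = 1.0     # weight on the distillation term

ce_loss = 2.5      # stand-in for outputs["loss"]
loss_kd = 0.8      # stand-in for the KD loss above

loss = kd_ce_alpha * ce_loss + kd_alpha * loss_kd
print(loss)  # 0.2 * 2.5 + 1.0 * 0.8 = 1.3
```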
@@ -166,6 +178,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
             target_logprobs,
             target_mask,
             num_items_in_batch=num_items_in_batch,
+            kd_temperature=self.args.kd_temperature,
         )
 
         if self.args.kd_ce_alpha > 0:
@@ -172,7 +172,19 @@ class AxolotlTrainingMixins:
     kd_ce_alpha: Optional[float] = field(
         default=None,
         metadata={
-            "help": "The alpha parameter for SFT cross entropy loss when using KD"
+            "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
         },
     )
 
+    kd_alpha: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The alpha scaling parameter for KD loss"},
+    )
+
+    kd_temperature: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "The temperature parameter for KL divergence loss when using KD"
+        },
+    )
+
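Note that the defaults preserve the old behavior: with `kd_temperature=1.0` both `if kd_temperature != 1.0` branches in `kd_loss_function` are no-ops, and with `kd_alpha=1.0` the combined loss reduces to the previous `kd_ce_alpha * CE + KD` form, so existing KD configs are unaffected.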
@@ -474,10 +474,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         roles_to_train=None,
         train_on_eos=None,
         logprobs_field="logprobs",
-        temperature=1.0,
+        gen_temperature=1.0,
+        kd_temperature=1.0,
     ):
         self.logprobs_field = logprobs_field
-        self.temperature = temperature
+        self.gen_temperature = gen_temperature
+        self.kd_temperature = kd_temperature
 
         super().__init__(
             prompter,
@@ -531,14 +533,25 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
                 position_logprobs, dtype=torch.float
             )
 
             # Apply temperature scaling at data load time
             # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
-            position_logprobs_tensor = position_logprobs_tensor / self.temperature
+            position_logprobs_tensor = position_logprobs_tensor / self.gen_temperature
             # normalize to probabilities so they sum up to 1
             position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
                 position_logprobs_tensor, dim=0, keepdim=True
             )
+            if self.kd_temperature != self.gen_temperature:
+                # Now we have the distribution at T1 in log form, i.e. log p_{T1}(k).
+                # Next, re-scale to T2 = self.kd_temperature via the exponent trick:
+                #   p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
+                # Convert from log to probability
+                teacher_probs_t1 = position_logprobs_tensor.exp()
+                # Exponentiate by factor (T1 / T2)
+                exponent = self.gen_temperature / self.kd_temperature
+                teacher_probs_t2 = teacher_probs_t1**exponent
+                # Re-normalize
+                teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
+                    dim=0, keepdim=True
+                )
+                # Convert back to log
+                position_logprobs_tensor = torch.log(teacher_probs_t2)
 
+            # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
             position_logprobs_scaled = position_logprobs_tensor.tolist()
 
             target_logprobs.append(position_logprobs_scaled)
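The exponent trick is exact softmax algebra: if the logged teacher distribution is p_{T1}(k) ∝ exp(z_k / T1), then [p_{T1}(k)]^(T1/T2) ∝ exp(z_k / T2), and renormalizing recovers the distribution at T2 (restricted to the stored top-K, it recovers the renormalized top-K distribution). A self-contained check with invented logits and tensor names of my choosing:

```python
import torch

# Verify: renormalizing p_T1 ** (T1 / T2) equals softmax taken directly at T2.
z = torch.tensor([2.0, 0.5, -1.0])  # arbitrary "teacher logits"
t1, t2 = 0.8, 2.0                   # gen_temperature, kd_temperature

p_t1 = torch.softmax(z / t1, dim=0)         # distribution the data was logged at
p_t2_direct = torch.softmax(z / t2, dim=0)  # distribution we want to train against

p_t2 = p_t1 ** (t1 / t2)                    # exponent trick
p_t2 = p_t2 / p_t2.sum(dim=0, keepdim=True) # re-normalize

print(torch.allclose(p_t2, p_t2_direct, atol=1e-6))  # True
```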
@@ -593,13 +606,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
     }
 
     strategy_cls = ChatTemplateStrategy
-    if logprobs_field := ds_cfg.get("logprobs_field"):
-        strategy_params["logprobs_field"] = logprobs_field
-    if temperature := ds_cfg.get("temperature"):
-        strategy_params["temperature"] = temperature
-
-    if cfg.trainer == "kd" or logprobs_field:
-        strategy_cls = ChatTemplateStrategyWithKD
+    if cfg.trainer == "kd":
+        strategy_cls = ChatTemplateStrategyWithKD
+        if logprobs_field := ds_cfg.get("logprobs_field"):
+            strategy_params["logprobs_field"] = logprobs_field
+        if gen_temperature := ds_cfg.get("temperature"):
+            strategy_params["gen_temperature"] = gen_temperature
+        if kd_temperature := cfg.get("kd_temperature"):
+            strategy_params["kd_temperature"] = kd_temperature
 
     strategy = strategy_cls(
         ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
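For orientation, a sketch of the two config surfaces this wiring reads from; only the key names come from the diffs above, the values are invented:

```python
# Per-dataset settings (ds_cfg): properties of the logged teacher data.
ds_cfg = {
    "logprobs_field": "logprobs",  # where the teacher's top-K logprobs are stored
    "temperature": 0.8,            # temperature the teacher *generated* at -> gen_temperature
}

# Top-level settings (cfg): training-time choices.
cfg = {
    "trainer": "kd",
    "kd_temperature": 2.0,  # temperature to distill at; triggers the re-scaling above
    "kd_ce_alpha": 0.2,     # weight on the hard-label CE term
    "kd_alpha": 1.0,        # weight on the KD term
}
```

The split is deliberate: the generation temperature is a fixed property of the dataset, while `kd_temperature` is a per-run training choice, and the `ChatTemplateStrategyWithKD` re-scaling bridges any mismatch between the two.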
@@ -178,6 +178,7 @@ class SFTDataset(BaseModel):
     message_field_training: Optional[str] = None
     message_field_training_detail: Optional[str] = None
     logprobs_field: Optional[str] = None
+    temperature: Optional[float] = None
     roles_to_train: Optional[List[str]] = None
     train_on_eos: Optional[str] = None
     roles: Optional[Dict[str, List[str]]] = None
@@ -626,6 +627,8 @@ class AxolotlInputConfig(
     kd_ce_alpha: Optional[
         float
     ] = None  # loss coefficient for cross-entropy loss during KD
+    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
+    kd_temperature: Optional[float] = None  # temperature for the KD loss (may differ from the teacher's generation temperature)
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore