better rescaling for temperatures

This commit is contained in:
Wing Lian
2024-12-24 09:26:27 -05:00
parent e376e00386
commit 6ad809287b
5 changed files with 62 additions and 17 deletions

View File

@@ -680,6 +680,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
training_args_cls = (
AxolotlTrainingArguments

View File

@@ -16,6 +16,7 @@ def kd_loss_function(
target_logprobs,
target_mask,
num_items_in_batch: Optional[int] = None,
kd_temperature: float = 1.0,
):
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
@@ -28,9 +29,15 @@ def kd_loss_function(
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
# shape -> [B, teacher_seq_len, K]
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
)
    # Apply KD temperature to the student's logits:
# z_s(T) = z_s / T
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
@@ -53,6 +60,10 @@ def kd_loss_function(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# 9) Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or mean over valid tokens
if num_items_in_batch is not None:
# If you know how many items should be considered in the batch
@@ -129,7 +140,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
# optionally combine with CE loss
if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
@@ -166,6 +178,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_logprobs,
target_mask,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0:

View File

@@ -172,7 +172,19 @@ class AxolotlTrainingMixins:
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha parameter for SFT cross entropy loss when using KD"
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)

View File

@@ -474,10 +474,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
roles_to_train=None,
train_on_eos=None,
logprobs_field="logprobs",
temperature=1.0,
gen_temperature=1.0,
kd_temperature=1.0,
):
self.logprobs_field = logprobs_field
self.temperature = temperature
self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
super().__init__(
prompter,
@@ -531,14 +533,25 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_logprobs, dtype=torch.float
)
# Apply temperature scaling at data load time
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
position_logprobs_tensor = position_logprobs_tensor / self.temperature
# normalize to probabilities so they sum up to 1
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
position_logprobs_tensor, dim=0, keepdim=True
)
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
@@ -593,13 +606,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
}
strategy_cls = ChatTemplateStrategy
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if temperature := ds_cfg.get("temperature"):
strategy_params["temperature"] = temperature
if cfg.trainer == "kd" or logprobs_field:
if cfg.trainer == "kd":
strategy_cls = ChatTemplateStrategyWithKD
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params

View File

@@ -178,6 +178,7 @@ class SFTDataset(BaseModel):
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
temperature: Optional[float] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
@@ -626,6 +627,8 @@ class AxolotlInputConfig(
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
    kd_temperature: Optional[float] = None  # temperature for KL divergence loss during KD
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore