From 6ad809287bdf58ee48e98d1ebb8cf1deb9127c27 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Dec 2024 09:26:27 -0500 Subject: [PATCH] better rescaling for temperatures --- src/axolotl/core/trainer_builder.py | 2 + src/axolotl/core/trainers/kd.py | 17 +++++++- src/axolotl/core/training_args.py | 14 +++++- .../prompt_strategies/chat_template.py | 43 +++++++++++++------ .../config/models/input/v0_4_1/__init__.py | 3 ++ 5 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ce4c094f1..43ebf3170 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -680,6 +680,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.kd_ce_alpha is not None: training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha + if self.cfg.kd_alpha is not None: + training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_args_cls = ( AxolotlTrainingArguments diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index 893c529d8..6047c72ff 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -16,6 +16,7 @@ def kd_loss_function( target_logprobs, target_mask, num_items_in_batch: Optional[int] = None, + kd_temperature: float = 1.0, ): # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding @@ -28,9 +29,15 @@ def kd_loss_function( ] # [B, teacher_seq_len, vocab_size] # Gather student logits for teacher's top-K tokens + # shape -> [B, teacher_seq_len, K] student_logits_topk = torch.gather( student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] + ) + + # Apply KD temperature to student’s logits: + # z_s(T) = z_s / T + if kd_temperature != 1.0: + student_logits_topk = student_logits_topk / kd_temperature # Convert student top-k logits to logprobs student_logprobs_topk = student_logits_topk - torch.logsumexp( @@ -53,6 +60,10 
@@ def kd_loss_function( kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss = kd_loss_per_token.sum() + # 9) Multiply by T^2 (classical KD scaling) + if kd_temperature != 1.0: + kd_loss = kd_loss * (kd_temperature**2) + # Normalize by number of items or mean over valid tokens if num_items_in_batch is not None: # If you know how many items should be considered in the batch @@ -129,7 +140,8 @@ class AxolotlKDTrainer(AxolotlTrainer): # optionally combine with CE loss if self.args.kd_ce_alpha > 0: - loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd + kd_alpha = self.args.kd_alpha + loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd else: loss = loss_kd @@ -166,6 +178,7 @@ class AxolotlKDTrainer(AxolotlTrainer): target_logprobs, target_mask, num_items_in_batch=num_items_in_batch, + kd_temperature=self.args.kd_temperature, ) if self.args.kd_ce_alpha > 0: diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index d06775993..36c43025d 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -172,7 +172,19 @@ class AxolotlTrainingMixins: kd_ce_alpha: Optional[float] = field( default=None, metadata={ - "help": "The alpha parameter for SFT cross entropy loss when using KD" + "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" + }, + ) + + kd_alpha: Optional[float] = field( + default=1.0, + metadata={"help": "The alpha scaling parameter for KD loss"}, + ) + + kd_temperature: Optional[float] = field( + default=1.0, + metadata={ + "help": "the temperature parameter for KL divergence loss when using KD" }, ) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index ea13f634b..223bf14b4 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -474,10 +474,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): 
roles_to_train=None, train_on_eos=None, logprobs_field="logprobs", - temperature=1.0, + gen_temperature=1.0, + kd_temperature=1.0, ): self.logprobs_field = logprobs_field - self.temperature = temperature + self.gen_temperature = gen_temperature + self.kd_temperature = kd_temperature super().__init__( prompter, @@ -531,14 +533,25 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): position_logprobs, dtype=torch.float ) - # Apply temperature scaling at data load time - # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T) - position_logprobs_tensor = position_logprobs_tensor / self.temperature - # normalize to probabilities so they sum up to 1 - position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp( - position_logprobs_tensor, dim=0, keepdim=True - ) + if self.kd_temperature != self.gen_temperature: + # + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). + # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor position_logprobs_scaled = position_logprobs_tensor.tolist() target_logprobs.append(position_logprobs_scaled) @@ -593,13 +606,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None } strategy_cls = ChatTemplateStrategy - if logprobs_field := ds_cfg.get("logprobs_field"): - strategy_params["logprobs_field"] = logprobs_field - if temperature := ds_cfg.get("temperature"): - strategy_params["temperature"] = temperature - if cfg.trainer == "kd" or 
logprobs_field: + if cfg.trainer == "kd": strategy_cls = ChatTemplateStrategyWithKD + if logprobs_field := ds_cfg.get("logprobs_field"): + strategy_params["logprobs_field"] = logprobs_field + if gen_temperature := ds_cfg.get("temperature"): + strategy_params["gen_temperature"] = gen_temperature + if kd_temperature := cfg.get("kd_temperature"): + strategy_params["kd_temperature"] = kd_temperature strategy = strategy_cls( ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 75d40af04..a0ba7efe9 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -178,6 +178,7 @@ class SFTDataset(BaseModel): message_field_training: Optional[str] = None message_field_training_detail: Optional[str] = None logprobs_field: Optional[str] = None + temperature: Optional[float] = None roles_to_train: Optional[List[str]] = None train_on_eos: Optional[str] = None roles: Optional[Dict[str, List[str]]] = None @@ -626,6 +627,8 @@ class AxolotlInputConfig( kd_ce_alpha: Optional[ float ] = None # loss coefficient for cross-entropy loss during KD + kd_alpha: Optional[float] = None # loss coefficient for KD loss + kd_temperature: Optional[float] = None # temperature applied to logits for the KD (KL divergence) loss datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore