diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 694bed808..ce4c094f1 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -678,6 +678,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "accelerator_config"
             ] = self.cfg.accelerator_config
 
+        if self.cfg.kd_ce_alpha is not None:
+            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
+
         training_args_cls = (
             AxolotlTrainingArguments
             if not self.cfg.reward_model
diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index bec18821d..14d37b776 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -109,6 +109,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
             num_items_in_batch=num_items_in_batch,
         )
 
+        # Optionally blend the model's own loss (weighted by kd_ce_alpha) into the KD loss.
+        # kd_ce_alpha defaults to None, so guard before the numeric comparison.
+        if self.args.kd_ce_alpha is not None and self.args.kd_ce_alpha > 0:
+            loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
+        else:
+            loss = loss_kd
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
@@ -119,4 +125,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
         if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
-            loss_kd *= self.accelerator.num_processes
+            loss *= self.accelerator.num_processes
 
-        return (loss_kd, outputs) if return_outputs else loss_kd
+        return (loss, outputs) if return_outputs else loss
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 6a8753e23..d06775993 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -169,6 +169,13 @@ class AxolotlTrainingMixins:
         metadata={"help": "Chat template converting chat messages to text"},
     )
 
+    kd_ce_alpha: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The alpha parameter for SFT cross entropy loss when using KD"
+        },
+    )
+
 
 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 83ede4514..75d40af04 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -623,6 +623,9 @@ class AxolotlInputConfig(
     ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
 
     trainer: Optional[Literal["kd"]] = None
+    kd_ce_alpha: Optional[
+        float
+    ] = None  # loss coefficient for cross-entropy loss during KD
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore