cross entropy loss coefficient during KD

This commit is contained in:
Wing Lian
2024-12-19 01:42:21 -05:00
parent b592c05b93
commit ae545e0165
4 changed files with 18 additions and 1 deletions

View File

@@ -678,6 +678,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model

View File

@@ -109,6 +109,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
num_items_in_batch=num_items_in_batch,
)
if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
@@ -119,4 +123,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss_kd *= self.accelerator.num_processes
return (loss_kd, outputs) if return_outputs else loss_kd
return (loss, outputs) if return_outputs else loss

View File

@@ -169,6 +169,13 @@ class AxolotlTrainingMixins:
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha parameter for SFT cross entropy loss when using KD"
},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -623,6 +623,9 @@ class AxolotlInputConfig(
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
trainer: Optional[Literal["kd"]] = None
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore