cross entropy loss coefficient during KD
This commit is contained in:
@@ -678,6 +678,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
|
||||
training_args_cls = (
|
||||
AxolotlTrainingArguments
|
||||
if not self.cfg.reward_model
|
||||
|
||||
@@ -109,6 +109,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
@@ -119,4 +123,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss_kd *= self.accelerator.num_processes
|
||||
|
||||
return (loss_kd, outputs) if return_outputs else loss_kd
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
@@ -169,6 +169,13 @@ class AxolotlTrainingMixins:
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
kd_ce_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The alpha parameter for SFT cross entropy loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
|
||||
@@ -623,6 +623,9 @@ class AxolotlInputConfig(
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
|
||||
trainer: Optional[Literal["kd"]] = None
|
||||
kd_ce_alpha: Optional[
|
||||
float
|
||||
] = None # loss coefficient for cross-entropy loss during KD
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user