cross entropy loss coefficient during KD
This commit is contained in:
@@ -678,6 +678,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"accelerator_config"
|
"accelerator_config"
|
||||||
] = self.cfg.accelerator_config
|
] = self.cfg.accelerator_config
|
||||||
|
|
||||||
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
|
|
||||||
training_args_cls = (
|
training_args_cls = (
|
||||||
AxolotlTrainingArguments
|
AxolotlTrainingArguments
|
||||||
if not self.cfg.reward_model
|
if not self.cfg.reward_model
|
||||||
|
|||||||
@@ -109,6 +109,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.kd_ce_alpha > 0:
|
||||||
|
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
|
||||||
|
else:
|
||||||
|
loss = loss_kd
|
||||||
# Save past state if it exists
|
# Save past state if it exists
|
||||||
# TODO: this needs to be fixed and made cleaner later.
|
# TODO: this needs to be fixed and made cleaner later.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
@@ -119,4 +123,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||||
loss_kd *= self.accelerator.num_processes
|
loss_kd *= self.accelerator.num_processes
|
||||||
|
|
||||||
return (loss_kd, outputs) if return_outputs else loss_kd
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|||||||
@@ -169,6 +169,13 @@ class AxolotlTrainingMixins:
|
|||||||
metadata={"help": "Chat template converting chat messages to text"},
|
metadata={"help": "Chat template converting chat messages to text"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kd_ce_alpha: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The alpha parameter for SFT cross entropy loss when using KD"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
|||||||
@@ -623,6 +623,9 @@ class AxolotlInputConfig(
|
|||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
|
||||||
trainer: Optional[Literal["kd"]] = None
|
trainer: Optional[Literal["kd"]] = None
|
||||||
|
kd_ce_alpha: Optional[
|
||||||
|
float
|
||||||
|
] = None # loss coefficient for cross-entropy loss during KD
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
|
|||||||
Reference in New Issue
Block a user