moved some DPOTrainer args to DPOConfig for future trl release
This commit is contained in:
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.12.0
|
trl @ git++https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
|||||||
@@ -1926,16 +1926,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
|
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
training_args_kwargs["max_target_length"] = None
|
||||||
|
if self.cfg.max_prompt_len is not None:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
|
if self.cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
if self.cfg.dpo_label_smoothing:
|
||||||
|
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
|
|
||||||
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
|
training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs
|
||||||
|
|
||||||
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
output_dir=self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
@@ -1955,27 +1971,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_args = self.build_training_arguments(total_num_steps)
|
training_args = self.build_training_arguments(total_num_steps)
|
||||||
dpo_trainer_kwargs = {}
|
dpo_trainer_kwargs = {}
|
||||||
if self.cfg.rl == "ipo":
|
|
||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
|
||||||
if self.cfg.dpo_label_smoothing:
|
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
|
||||||
dpo_trainer_kwargs[
|
|
||||||
"precompute_ref_log_probs"
|
|
||||||
] = self.cfg.precompute_ref_log_probs
|
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
|
||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
Reference in New Issue
Block a user