moved some DPOTrainer args to DPOConfig for future trl release
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl==0.12.0
+trl @ git+https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7
 zstandard==0.22.0
 fastcore
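The pin moves from the 0.12.0 release to a specific trl commit, since the DPOConfig-based argument handling below targets a not-yet-released trl. A minimal sanity check, assuming only that the pinned commit exposes these fields on DPOConfig (the version string and output dir are illustrative):

    # Hedged sketch: confirm the environment picked up a trl build whose
    # DPOConfig carries the arguments this commit moves off of DPOTrainer.
    import trl
    from trl import DPOConfig

    print(trl.__version__)  # a git install typically reports a ".dev" version

    cfg = DPOConfig(output_dir="/tmp/check")  # hypothetical output dir
    for field in ("max_length", "max_prompt_length", "precompute_ref_log_probs"):
        assert hasattr(cfg, field), f"pinned trl is missing DPOConfig.{field}"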
@@ -1926,16 +1926,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                self.cfg.kto_undesirable_weight or 1.0
            )

            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
            training_args_kwargs["max_length"] = self.cfg.sequence_len
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        else:
            training_args_cls = AxolotlDPOConfig

            training_args_kwargs["max_length"] = self.cfg.sequence_len

            training_args_kwargs["max_target_length"] = None
            if self.cfg.max_prompt_len is not None:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

            if self.cfg.dpo_use_weighting is not None:
                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting

            if self.cfg.rl == "ipo":
                training_args_kwargs["loss_type"] = "ipo"
            if self.cfg.dpo_label_smoothing:
                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing

            if self.cfg.precompute_ref_log_probs is not None:
                training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs

            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb

        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
            output_dir=self.cfg.output_dir,
            per_device_train_batch_size=self.cfg.micro_batch_size,
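In the DPO/IPO branch, the length limits, weighting, loss type, label smoothing, reference-log-prob precomputation, and eval-generation flags now land in the config kwargs rather than being handed to the trainer later. A condensed sketch of the pattern, with trl's DPOConfig standing in for the AxolotlDPOConfig used above and invented literal values:

    # Hedged sketch of the kwargs-to-config pattern above; all values are
    # made up for illustration and mirror the cfg fields named in the diff.
    from trl import DPOConfig

    training_args_kwargs = {
        "max_length": 2048,             # cfg.sequence_len
        "max_prompt_length": 512,       # cfg.max_prompt_len
        "use_weighting": False,         # cfg.dpo_use_weighting (trl >= 0.11)
        "loss_type": "ipo",             # only set when cfg.rl == "ipo"
        "label_smoothing": 0.1,         # cfg.dpo_label_smoothing
        "precompute_ref_log_probs": False,
        "generate_during_eval": False,  # mirrors cfg.use_wandb
    }

    # Unknown keys fail loudly at construction time, which is the point of
    # moving them onto the config ahead of the next trl release.
    training_args = DPOConfig(output_dir="outputs", **training_args_kwargs)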
@@ -1955,27 +1971,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
    def build(self, total_num_steps):
        training_args = self.build_training_arguments(total_num_steps)
        dpo_trainer_kwargs = {}
        if self.cfg.rl == "ipo":
            dpo_trainer_kwargs["loss_type"] = "ipo"
        if self.cfg.dpo_label_smoothing:
            dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing

        if self.eval_dataset:
            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
        if self.cfg.adapter and self.peft_config:
            dpo_trainer_kwargs["peft_config"] = self.peft_config
        if self.cfg.precompute_ref_log_probs is not None:
            dpo_trainer_kwargs[
                "precompute_ref_log_probs"
            ] = self.cfg.precompute_ref_log_probs

        if self.cfg.rl in ["dpo", "ipo"]:
            trainer_cls = AxolotlDPOTrainer
            trainer_cls_args = [self.model, self.model_ref]

            # these aren't used for the ORPO trainer
            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
            dpo_trainer_kwargs["max_target_length"] = None
            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
        elif self.cfg.rl == "orpo":
            trainer_cls = AxolotlORPOTrainer
            trainer_cls_args = [self.model]
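With the loss-type, label-smoothing, precompute, and length/eval-generation settings migrated off the trainer kwargs in build(), construction reduces to models, data, and the config. A self-contained sketch of the post-change shape against plain trl (AxolotlDPOTrainer builds on trl's DPOTrainer; the model id, dataset, and values below are invented for illustration):

    # Hedged sketch: length limits live on DPOConfig, and DPOTrainer receives
    # only models, data, and the config.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model_id = "sshleifer/tiny-gpt2"  # tiny test model, purely for illustration
    model = AutoModelForCausalLM.from_pretrained(model_id)
    ref_model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 family has no pad token

    # one toy preference pair in the prompt/chosen/rejected format DPOTrainer expects
    train_dataset = Dataset.from_dict(
        {"prompt": ["2 + 2 ="], "chosen": [" 4"], "rejected": [" 5"]}
    )

    args = DPOConfig(
        output_dir="outputs",
        max_length=128,              # formerly a DPOTrainer kwarg
        max_prompt_length=64,        # formerly a DPOTrainer kwarg
        generate_during_eval=False,  # formerly a DPOTrainer kwarg
        per_device_train_batch_size=1,
        max_steps=1,
        report_to="none",
    )
    trainer = DPOTrainer(
        model,
        ref_model,
        args=args,
        train_dataset=train_dataset,
        processing_class=tokenizer,  # older trl releases call this `tokenizer`
    )
    trainer.train()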