FIX: TRL trainer preprocessing step was running in one process (#1583)
* FIX: TRL trainer preprocessing step was running in one process * FIX: Changed so that dataset_num_proc is sent to CPO, KTO and ORPO trainer args and directly to the trainer when DPO * FIX: Changed back to only support ORPO for now, since KTO is handled in another way --------- Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
This commit is contained in:
@@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
|
||||
@@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_cls = TrainingArguments
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
training_args = training_args_cls(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
if self.cfg.rl == "dpo":
|
||||
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
Reference in New Issue
Block a user