FIX: TRL trainer preprocessing step was running in one process (#1583)
* FIX: TRL trainer preprocessing step was running in one process * FIX: Changed so that dataset_num_proc is sent to CPO, KTO and ORPO trainer args and directly to the trainer when DPO * FIX: Changed back to only support ORPO for now, since KTO is handled in another way --------- Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
This commit is contained in:
@@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["evaluation_strategy"] = "no"
|
training_args_kwargs["evaluation_strategy"] = "no"
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
|
|
||||||
@@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_cls = TrainingArguments
|
training_args_cls = TrainingArguments
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
training_args = training_args_cls(
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
@@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["max_target_length"] = None
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||||
|
if self.cfg.rl == "dpo":
|
||||||
|
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
Reference in New Issue
Block a user