FIX: max_length and max_prompt_length was not being sent to ORPOTrainer (#1584)
* FIX: TRL trainer preprocessing step was running in one process * FIX: max_length and max_prompt_length was not being sent to ORPOTrainer * FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour * FIX: Removed change from a different PR * FIX: Black fix * explicitly set max prompt len for orpo config --------- Co-authored-by: Ali Mosavian <ali.mosavian@kry.se> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
training_args = training_args_cls(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
|
||||
@@ -517,6 +517,9 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||
)
|
||||
sample_packing: Optional[bool] = None
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
|
||||
Reference in New Issue
Block a user