From 1e1921b79421abc54feaa3e01e7a927b9343336c Mon Sep 17 00:00:00 2001 From: Ali Mosavian Date: Tue, 14 May 2024 14:51:17 +0200 Subject: [PATCH] FIX: max_length and max_prompt_length were not being sent to ORPOTrainer (#1584) * FIX: TRL trainer preprocessing step was running in one process * FIX: max_length and max_prompt_length were not being sent to ORPOTrainer * FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour * FIX: Removed change from a different PR * FIX: Black fix * explicitly set max prompt len for orpo config --------- Co-authored-by: Ali Mosavian Co-authored-by: Wing Lian --- src/axolotl/core/trainer_builder.py | 3 +++ src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 742a88633..2f38b12dc 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "orpo": training_args_cls = ORPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + training_args_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.max_prompt_len: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0a2442d50..f1c12b2ba 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -517,6 +517,9 @@ class AxolotlInputConfig( sequence_len: int = Field(default=512) min_sample_len: Optional[int] = None + max_prompt_len: int = Field( + default=512, metadata={"help": "maximum prompt length for RL training"} + ) sample_packing: Optional[bool] = None 
eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None