From 7be8740c5c0a2ced61eb9c3ac2afa220baf4ff1e Mon Sep 17 00:00:00 2001
From: AlexHT Hung
Date: Fri, 19 Sep 2025 18:34:28 +0800
Subject: [PATCH] fix(rl): pass max_prompt_len to training args as max_prompt_length (#3113)

* pass max_prompt_len to training args as max_prompt_length

* Update rl.py

* refactor

* format

* fix: default for max_prompt_length

* fix: defaults for trainer

---------

Co-authored-by: NanoCode012
---
 src/axolotl/core/builders/rl.py           | 18 +++++++++++++-----
 src/axolotl/core/trainers/dpo/__init__.py |  1 -
 src/axolotl/utils/schemas/config.py       |  4 ++--
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index a6e8355f4..0ceb80008 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.use_wandb:
             training_args_kwargs["run_name"] = self.cfg.wandb_name
 
+        if self.cfg.max_prompt_len:
+            training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+        else:
+            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
+
         training_args_cls = None
         blocklist_args_kwargs = []
         if self.cfg.rl is RLType.SIMPO:
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
 
+            # Handle when max_prompt_length == max_length from defaults
+            # CPOTrainer requires strictly less than
+            if (
+                training_args_kwargs["max_prompt_length"]
+                == training_args_kwargs["max_length"]
+            ):
+                training_args_kwargs["max_prompt_length"] -= 1
+
         elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
-            if self.cfg.max_prompt_len:
-                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
         elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kto_undesirable_weight or 1.0
             )
 
-            if self.cfg.max_prompt_len:
-                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
-
         elif self.cfg.rl is RLType.GRPO:
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py
index 4b40d4085..3aa79c484 100644
--- a/src/axolotl/core/trainers/dpo/__init__.py
+++ b/src/axolotl/core/trainers/dpo/__init__.py
@@ -27,7 +27,6 @@ class DPOStrategy:
             training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
         training_args_kwargs["max_completion_length"] = None
         training_args_kwargs["max_length"] = cfg.sequence_len
-        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
         training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d612ec8a5..0177b19f6 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -436,8 +436,8 @@ class AxolotlInputConfig(
         },
     )
     min_sample_len: int | None = None
-    max_prompt_len: int = Field(
-        default=512,
+    max_prompt_len: int | None = Field(
+        default=None,
         json_schema_extra={"description": "maximum prompt length for RL training"},
     )
     sample_packing: bool | None = Field(