diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index a6e8355f4..0ceb80008 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.use_wandb:
             training_args_kwargs["run_name"] = self.cfg.wandb_name
 
+        if self.cfg.max_prompt_len:
+            training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+        else:
+            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
+
         training_args_cls = None
         blocklist_args_kwargs = []
         if self.cfg.rl is RLType.SIMPO:
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
 
+            # Handle when max_prompt_length == max_length from defaults
+            # CPOTrainer requires strictly less than
+            if (
+                training_args_kwargs["max_prompt_length"]
+                == training_args_kwargs["max_length"]
+            ):
+                training_args_kwargs["max_prompt_length"] -= 1
+
         elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
-            if self.cfg.max_prompt_len:
-                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
         elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kto_undesirable_weight or 1.0
             )
 
-            if self.cfg.max_prompt_len:
-                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
-
         elif self.cfg.rl is RLType.GRPO:
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py
index 4b40d4085..3aa79c484 100644
--- a/src/axolotl/core/trainers/dpo/__init__.py
+++ b/src/axolotl/core/trainers/dpo/__init__.py
@@ -27,7 +27,6 @@ class DPOStrategy:
         training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
         training_args_kwargs["max_completion_length"] = None
         training_args_kwargs["max_length"] = cfg.sequence_len
-        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
         training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d612ec8a5..0177b19f6 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -436,8 +436,8 @@ class AxolotlInputConfig(
         },
     )
     min_sample_len: int | None = None
-    max_prompt_len: int = Field(
-        default=512,
+    max_prompt_len: int | None = Field(
+        default=None,
         json_schema_extra={"description": "maximum prompt length for RL training"},
     )
     sample_packing: bool | None = Field(
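
For reference, the resolution order the patch introduces: an explicit max_prompt_len wins; otherwise max_prompt_length falls back to cfg.sequence_len; and on the CPO/SIMPO path the value is clamped so it stays strictly below max_length, which CPOTrainer requires. The sketch below restates that logic as a standalone function purely for illustration. The helper name resolve_max_prompt_length is hypothetical, and it assumes max_length == cfg.sequence_len, as set elsewhere in the builder; in the patch itself the logic lives inline in HFRLTrainerBuilder.

    def resolve_max_prompt_length(
        max_prompt_len: int | None,
        sequence_len: int,
        is_cpo: bool = False,
    ) -> int:
        # Prefer an explicit max_prompt_len; otherwise fall back to sequence_len.
        max_prompt_length = max_prompt_len if max_prompt_len else sequence_len
        # CPOTrainer requires max_prompt_length strictly less than max_length,
        # so clamp when the fallback made the two equal.
        if is_cpo and max_prompt_length == sequence_len:
            max_prompt_length -= 1
        return max_prompt_length

    assert resolve_max_prompt_length(None, 2048) == 2048              # fallback
    assert resolve_max_prompt_length(512, 2048) == 512                # explicit value wins
    assert resolve_max_prompt_length(None, 2048, is_cpo=True) == 2047 # clamped for CPO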