fix(rl): pass max_prompt_len to training args as max_prompt_length (#3113)
* pass max_prompt_len to training args as max_prompt_length
* Update rl.py
* refactor
* format
* fix: default for max_prompt_length
* fix: defaults for trainer

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
|
if self.cfg.max_prompt_len:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
else:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl is RLType.SIMPO:
|
if self.cfg.rl is RLType.SIMPO:
|
||||||
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.cpo_alpha is not None:
|
if self.cfg.cpo_alpha is not None:
|
||||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|
||||||
|
# Handle when max_prompt_length == max_length from defaults
|
||||||
|
# CPOTrainer requires strictly less than
|
||||||
|
if (
|
||||||
|
training_args_kwargs["max_prompt_length"]
|
||||||
|
== training_args_kwargs["max_length"]
|
||||||
|
):
|
||||||
|
training_args_kwargs["max_prompt_length"] -= 1
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.KTO:
|
elif self.cfg.rl is RLType.KTO:
|
||||||
training_args_cls = AxolotlKTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.GRPO:
|
elif self.cfg.rl is RLType.GRPO:
|
||||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
|
|||||||
@@ -436,8 +436,8 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int | None = Field(
|
||||||
default=512,
|
default=None,
|
||||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||||
)
|
)
|
||||||
sample_packing: bool | None = Field(
|
sample_packing: bool | None = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user