diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py index b584d3e88..578f68517 100644 --- a/src/axolotl/core/trainer_builder/rl.py +++ b/src/axolotl/core/trainer_builder/rl.py @@ -208,6 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] + + trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( + trainer_kwargs, trainer_cls + ) + trainer = trainer_cls( *trainer_cls_args, args=training_args,