diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py index d2a4785ef..df2b32904 100644 --- a/src/axolotl/core/trainer_builder/rl.py +++ b/src/axolotl/core/trainer_builder/rl.py @@ -37,6 +37,45 @@ class HFRLTrainerBuilder(TrainerBuilderBase): callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) return callbacks + def _get_trainer_cls(self, trainer_kwargs: dict): + """ + Returns trainer_cls and trainer_cls_args + """ + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + trainer_cls_args = [] # type: ignore + + if trainer_cls is not None: + return trainer_cls, trainer_cls_args + + if self.cfg.rl is RLType.GRPO: + trainer_cls = GRPOStrategy.get_trainer_class( + sequence_parallel=self.cfg.sequence_parallel_degree > 1 + ) + trainer_cls_args = [self.model] + trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) + + trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) + + elif self.cfg.rl in [RLType.DPO, RLType.IPO]: + trainer_cls = DPOStrategy.get_trainer_class() + trainer_cls_args = [self.model, self.model_ref] + + elif self.cfg.rl is RLType.ORPO: + trainer_cls = AxolotlORPOTrainer + trainer_cls_args = [self.model] + elif self.cfg.rl is RLType.KTO: + trainer_cls = AxolotlKTOTrainer + trainer_cls_args = [self.model] + elif self.cfg.rl is RLType.SIMPO: + trainer_cls = AxolotlCPOTrainer + trainer_cls_args = [self.model] + else: + raise ValueError(f"Unsupported RL: {self.cfg.rl}") + + return trainer_cls, trainer_cls_args + def build_training_arguments(self, total_num_steps): training_args_kwargs = self._set_base_training_args( total_num_steps=total_num_steps @@ -152,31 +191,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs ) - if self.cfg.rl is RLType.GRPO: - trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.sequence_parallel_degree > 1 - ) - trainer_cls_args = [self.model] - trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) - trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) - elif self.cfg.rl in [RLType.DPO, RLType.IPO]: - trainer_cls = DPOStrategy.get_trainer_class() - trainer_cls_args = [self.model, self.model_ref] - elif self.cfg.rl is RLType.ORPO: - trainer_cls = AxolotlORPOTrainer - trainer_cls_args = [self.model] - elif self.cfg.rl is RLType.KTO: - trainer_cls = AxolotlKTOTrainer - trainer_cls_args = [self.model] - elif self.cfg.rl is RLType.SIMPO: - trainer_cls = AxolotlCPOTrainer - trainer_cls_args = [self.model] - else: - raise ValueError(f"Unsupported RL: {self.cfg.rl}") - if self.cfg.plugins: - plugin_manager = PluginManager.get_instance() - trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs) sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys():