fix: update handling of trainer_cls in RL

This commit is contained in:
NanoCode012
2025-05-16 14:23:28 +07:00
parent 0b40f2aaf6
commit 00bfdb6b2b

View File

@@ -37,6 +37,45 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None:
return trainer_cls, trainer_cls_args
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args
def build_training_arguments(self, total_num_steps):
training_args_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
@@ -152,31 +191,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():