fix: update handling of trainer_cls in RL
This commit is contained in:
@@ -37,6 +37,45 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||
"""
|
||||
Returns trainer_cls and trainer_cls_args
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls_args = [] # type: ignore
|
||||
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args = [self.model]
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
trainer_cls = AxolotlKTOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl is RLType.SIMPO:
|
||||
trainer_cls = AxolotlCPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
training_args_kwargs = self._set_base_training_args(
|
||||
total_num_steps=total_num_steps
|
||||
@@ -152,31 +191,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args = [self.model]
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
trainer_cls = AxolotlKTOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl is RLType.SIMPO:
|
||||
trainer_cls = AxolotlCPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
|
||||
Reference in New Issue
Block a user