fix: update handling of trainer_cls in RL
This commit is contained in:
@@ -37,6 +37,45 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||||
|
"""
|
||||||
|
Returns trainer_cls and trainer_cls_args
|
||||||
|
"""
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
|
trainer_cls_args = [] # type: ignore
|
||||||
|
|
||||||
|
if trainer_cls is not None:
|
||||||
|
return trainer_cls, trainer_cls_args
|
||||||
|
|
||||||
|
if self.cfg.rl is RLType.GRPO:
|
||||||
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
|
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||||
|
)
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
|
||||||
|
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
|
||||||
|
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
|
trainer_cls = AxolotlORPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
elif self.cfg.rl is RLType.KTO:
|
||||||
|
trainer_cls = AxolotlKTOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
elif self.cfg.rl is RLType.SIMPO:
|
||||||
|
trainer_cls = AxolotlCPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
|
return trainer_cls, trainer_cls_args
|
||||||
|
|
||||||
def build_training_arguments(self, total_num_steps):
|
def build_training_arguments(self, total_num_steps):
|
||||||
training_args_kwargs = self._set_base_training_args(
|
training_args_kwargs = self._set_base_training_args(
|
||||||
total_num_steps=total_num_steps
|
total_num_steps=total_num_steps
|
||||||
@@ -152,31 +191,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
)
|
)
|
||||||
if self.cfg.rl is RLType.GRPO:
|
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
|
||||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
|
||||||
)
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
|
||||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
|
||||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
|
||||||
trainer_cls = DPOStrategy.get_trainer_class()
|
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
|
||||||
trainer_cls = AxolotlORPOTrainer
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
elif self.cfg.rl is RLType.KTO:
|
|
||||||
trainer_cls = AxolotlKTOTrainer
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
elif self.cfg.rl is RLType.SIMPO:
|
|
||||||
trainer_cls = AxolotlCPOTrainer
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
|
||||||
|
|
||||||
if self.cfg.plugins:
|
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters.keys():
|
if "tokenizer" in sig.parameters.keys():
|
||||||
|
|||||||
Reference in New Issue
Block a user