fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls * feat(test): add test to catch trainer_cls is not None
This commit is contained in:
@@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
if temp_trainer_cls is not None:
|
||||
trainer_cls = temp_trainer_cls
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
|
||||
Reference in New Issue
Block a user