support for custom trainer classes from plugins

This commit is contained in:
Wing Lian
2024-12-30 12:20:45 -05:00
parent fa055f9f69
commit c51b0337c1
2 changed files with 32 additions and 0 deletions

View File

@@ -295,6 +295,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":

View File

@@ -111,6 +111,17 @@ class BasePlugin:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -346,6 +357,22 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
if trainer_cls is not None:
return trainer_cls
return None
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.