From a0d6d8895ebd1f3f6c3649a8a4d036a93039b5a0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Dec 2024 12:20:45 -0500 Subject: [PATCH] support for custom trainer classes from plugins --- src/axolotl/core/trainer_builder.py | 5 +++++ src/axolotl/integrations/base.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index def1d7a26..7eadd3e59 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -295,6 +295,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if trainer_cls: + return trainer_cls if self.cfg.relora_steps: return ReLoRATrainer if self.cfg.model_config_type == "mamba": diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index a271c59d1..26f2f8a6f 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -111,6 +111,17 @@ class BasePlugin: None """ + def get_trainer_cls(self, cfg): # pylint: disable=unused-argument + """ + Returns a custom class for the trainer. + + Parameters: + cfg (dict): The global axolotl configuration. + + Returns: + class: The class for the trainer, or None to use the default trainer. + """ + def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument """ Creates and returns an optimizer for training. @@ -346,6 +357,22 @@ class PluginManager: for plugin in self.plugins.values(): plugin.post_lora_load(cfg, model) + def get_trainer_cls(self, cfg): + """ + Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + object: The trainer class, or None if none was found.
+ """ + for plugin in self.plugins.values(): + trainer_cls = plugin.get_trainer_cls(cfg) + if trainer_cls is not None: + return trainer_cls + return None + def create_optimizer(self, cfg, trainer): """ Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.