diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e47c09d51..aab9a80b8 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -48,6 +48,7 @@ from trl import ( ) from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length +from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils import is_comet_available, is_mlflow_available @@ -1147,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC): def get_callbacks(self) -> List[TrainerCallback]: callbacks = [] + + plugin_manager = PluginManager.get_instance() + callbacks.extend( + plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model) + ) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) @@ -1173,11 +1180,17 @@ class TrainerBuilderBase(abc.ABC): return callbacks - @abstractmethod def get_post_trainer_create_callbacks(self, trainer): """ Callbacks added after the trainer is created, usually b/c these need access to the trainer """ + callbacks = [] + + plugin_manager = PluginManager.get_instance() + callbacks.extend( + plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer) + ) + return callbacks def hook_pre_create_training_args(self, training_arguments_kwargs): # TODO @@ -1223,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def get_post_trainer_create_callbacks(self, trainer): - callbacks = [] + callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) if self.cfg.use_wandb and self.cfg.eval_table_size > 0: LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "wandb" @@ -1791,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): return callbacks def get_post_trainer_create_callbacks(self, trainer): - callbacks = [] + callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) return callbacks def build_training_arguments(self, total_num_steps): @@ -2000,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase): """ def get_callbacks(self): - callbacks = [] + callbacks = super().get_callbacks() return callbacks def get_post_trainer_create_callbacks(self, trainer): - callbacks = [] + callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) return callbacks def build(self, total_num_steps): diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index e2bd79bc4..43afa431a 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. """ +import collections import importlib import logging -from typing import List +from typing import OrderedDict class BasePlugin: @@ -47,7 +48,7 @@ class BasePlugin: Initializes the BasePlugin. """ - def register(self, cfg): + def register(self, cfg): # pylint: disable=unused-argument """ Registers the plugin with the given configuration. @@ -63,7 +64,7 @@ class BasePlugin: Returns a pydantic model for the plugin's input arguments. """ - def pre_model_load(self, cfg): + def pre_model_load(self, cfg): # pylint: disable=unused-argument """ Performs actions before the model is loaded. @@ -74,7 +75,7 @@ class BasePlugin: None """ - def post_model_load(self, cfg, model): + def post_model_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after the model is loaded. @@ -86,7 +87,7 @@ class BasePlugin: None """ - def pre_lora_load(self, cfg, model): + def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions before LoRA weights are loaded. @@ -98,7 +99,7 @@ class BasePlugin: None """ - def post_lora_load(self, cfg, model): + def post_lora_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after LoRA weights are loaded. @@ -110,7 +111,7 @@ class BasePlugin: None """ - def create_optimizer(self, cfg, trainer): + def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument """ Creates and returns an optimizer for training. @@ -122,7 +123,9 @@ class BasePlugin: object: The created optimizer. """ - def create_lr_scheduler(self, cfg, trainer, optimizer): + def create_lr_scheduler( + self, cfg, trainer, optimizer + ): # pylint: disable=unused-argument """ Creates and returns a learning rate scheduler. @@ -135,7 +138,7 @@ class BasePlugin: object: The created learning rate scheduler. """ - def add_callbacks_pre_trainer(self, cfg, model): + def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument """ Adds callbacks to the trainer before training. @@ -146,8 +149,11 @@ class BasePlugin: Returns: List[callable]: A list of callback functions to be added to the TrainingArgs """ + return [] - def add_callbacks_post_trainer(self, cfg, trainer): + def add_callbacks_post_trainer( + self, cfg, trainer + ): # pylint: disable=unused-argument """ Adds callbacks to the trainer after training. @@ -158,8 +164,9 @@ class BasePlugin: Returns: List[callable]: A list of callback functions to be added to the TrainingArgs """ + return [] - def post_train(self, cfg, model): + def post_train(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after training is complete. @@ -171,7 +178,7 @@ class BasePlugin: None """ - def post_train_unload(self, cfg): + def post_train_unload(self, cfg): # pylint: disable=unused-argument """ Performs actions after training is complete and the model is unloaded. @@ -227,7 +234,7 @@ class PluginManager: pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ - plugins: List[BasePlugin] = [] + plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() _instance = None @@ -237,7 +244,7 @@ class PluginManager: """ if cls._instance is None: cls._instance = super(PluginManager, cls).__new__(cls) - cls._instance.plugins: List[BasePlugin] = [] + cls._instance.plugins = collections.OrderedDict() return cls._instance @staticmethod @@ -265,7 +272,7 @@ class PluginManager: """ try: plugin = load_plugin(plugin_name) - self.plugins.append(plugin) + self.plugins[plugin_name] = plugin except ImportError: logging.error(f"Failed to load plugin: {plugin_name}") @@ -277,7 +284,7 @@ class PluginManager: list[str]: A list of Pydantic classes for all registered plugins' input arguments.' """ input_args = [] - for plugin in self.plugins: + for plugin in self.plugins.values(): input_args_from_plugin = plugin.get_input_args() if input_args_from_plugin is not None: input_args.append(input_args_from_plugin) @@ -293,7 +300,7 @@ class PluginManager: Returns: None """ - for plugin in self.plugins: + for plugin in self.plugins.values(): plugin.pre_model_load(cfg) def post_model_load(self, cfg, model): @@ -307,7 +314,7 @@ class PluginManager: Returns: None """ - for plugin in self.plugins: + for plugin in self.plugins.values(): plugin.post_model_load(cfg, model) def pre_lora_load(self, cfg, model): @@ -321,7 +328,7 @@ class PluginManager: Returns: None """ - for plugin in self.plugins: + for plugin in self.plugins.values(): plugin.pre_lora_load(cfg, model) def post_lora_load(self, cfg, model): @@ -335,7 +342,7 @@ class PluginManager: Returns: None """ - for plugin in self.plugins: + for plugin in self.plugins.values(): plugin.post_lora_load(cfg, model) def create_optimizer(self, cfg, trainer): @@ -349,7 +356,7 @@ class PluginManager: Returns: object: The created optimizer, or None if none was found. """ - for plugin in self.plugins: + for plugin in self.plugins.values(): optimizer = plugin.create_optimizer(cfg, trainer) if optimizer is not None: return optimizer @@ -367,7 +374,7 @@ class PluginManager: Returns: object: The created learning rate scheduler, or None if none was found. """ - for plugin in self.plugins: + for plugin in self.plugins.values(): scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) if scheduler is not None: return scheduler @@ -385,7 +392,7 @@ class PluginManager: List[callable]: A list of callback functions to be added to the TrainingArgs. """ callbacks = [] - for plugin in self.plugins: + for plugin in self.plugins.values(): callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model)) return callbacks @@ -401,7 +408,7 @@ class PluginManager: List[callable]: A list of callback functions to be added to the TrainingArgs. """ callbacks = [] - for plugin in self.plugins: + for plugin in self.plugins.values(): callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) return callbacks @@ -416,5 +423,5 @@ class PluginManager: Returns: None """ - for plugin in self.plugins: + for plugin in self.plugins.values(): plugin.post_train_unload(cfg)