Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow * Use .values() instead of .items()
This commit is contained in:
@@ -48,6 +48,7 @@ from trl import (
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -1147,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def get_callbacks(self) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
@@ -1173,11 +1180,17 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
return callbacks
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
@@ -1223,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1791,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
@@ -2000,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
|
||||
|
||||
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
||||
"""
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import OrderedDict
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -47,7 +48,7 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg):
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
|
||||
@@ -63,7 +64,7 @@ class BasePlugin:
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
@@ -74,7 +75,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
@@ -86,7 +87,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
@@ -98,7 +99,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
@@ -110,7 +111,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
@@ -122,7 +123,9 @@ class BasePlugin:
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer):
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
@@ -135,7 +138,7 @@ class BasePlugin:
|
||||
object: The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer before training.
|
||||
|
||||
@@ -146,8 +149,11 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer after training.
|
||||
|
||||
@@ -158,8 +164,9 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model):
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
@@ -171,7 +178,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
@@ -227,7 +234,7 @@ class PluginManager:
|
||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
"""
|
||||
|
||||
plugins: List[BasePlugin] = []
|
||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
||||
|
||||
_instance = None
|
||||
|
||||
@@ -237,7 +244,7 @@ class PluginManager:
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(PluginManager, cls).__new__(cls)
|
||||
cls._instance.plugins: List[BasePlugin] = []
|
||||
cls._instance.plugins = collections.OrderedDict()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
@@ -265,7 +272,7 @@ class PluginManager:
|
||||
"""
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins.append(plugin)
|
||||
self.plugins[plugin_name] = plugin
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
@@ -277,7 +284,7 @@ class PluginManager:
|
||||
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
||||
"""
|
||||
input_args = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
input_args_from_plugin = plugin.get_input_args()
|
||||
if input_args_from_plugin is not None:
|
||||
input_args.append(input_args_from_plugin)
|
||||
@@ -293,7 +300,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_model_load(cfg)
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
@@ -307,7 +314,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_model_load(cfg, model)
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
@@ -321,7 +328,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_lora_load(cfg, model)
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
@@ -335,7 +342,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_lora_load(cfg, model)
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
@@ -349,7 +356,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created optimizer, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
optimizer = plugin.create_optimizer(cfg, trainer)
|
||||
if optimizer is not None:
|
||||
return optimizer
|
||||
@@ -367,7 +374,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created learning rate scheduler, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
||||
if scheduler is not None:
|
||||
return scheduler
|
||||
@@ -385,7 +392,7 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||
return callbacks
|
||||
|
||||
@@ -401,7 +408,7 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
return callbacks
|
||||
|
||||
@@ -416,5 +423,5 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
Reference in New Issue
Block a user