diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 5585c88a7..2a4dcd288 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -18,6 +18,7 @@
 from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.common.datasets import load_datasets, load_preference_datasets
+from axolotl.integrations.base import PluginManager
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
         cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

     with disable_datasets_caching():
-        if cfg.rl:
+        plugin_manager = PluginManager.get_instance()
+        if plugin_manager.load_datasets(cfg, preprocess=True):
+            pass
+        elif cfg.rl:
             load_preference_datasets(cfg=cfg, cli_args=cli_args)
         else:
             load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 9e90cede3..777d84885 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         check_user_token()

-    if cfg.rl:
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
-    else:
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+    plugin_manager = PluginManager.get_instance()
+    dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
+    if not dataset_meta:
+        if cfg.rl:
+            dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        else:
+            dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index efe542af7..97cbac693 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -26,6 +26,8 @@ from typing import OrderedDict
 import torch
 from torch.optim.lr_scheduler import LRScheduler

+from axolotl.utils.dict import DictDefault
+

 class BasePlugin:
     """
@@ -36,11 +38,13 @@ class BasePlugin:

     Methods:
         register(cfg): Registers the plugin with the given configuration.
+        load_datasets(cfg): Loads and preprocesses the dataset for training.
         pre_model_load(cfg): Performs actions before the model is loaded.
         post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
         pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
         post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
         post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
+        post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
         create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
         create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
         add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ class BasePlugin:
         None
         """

-    def get_input_args(self):
+    def get_input_args(self) -> str | None:
         """
         Returns a pydantic model for the plugin's input arguments.
         """

+    def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
+        """
+        Loads and preprocesses the dataset for training.
+
+        Args:
+            cfg: The configuration for the plugin.
+            preprocess: Whether this is the dataset preprocessing step.
+
+        Returns:
+            dataset_meta: The metadata for the training dataset.
+        """
+
     def pre_model_load(self, cfg):  # pylint: disable=unused-argument
         """
         Performs actions before the model is loaded.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
+        Args:
+            cfg (dict): The configuration for the plugin.

         Returns:
-        None
+            None
         """

     def post_model_build(self, cfg, model):  # pylint: disable=unused-argument
@@ -91,59 +107,71 @@
         """
         Performs actions after the model is loaded.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        model (object): The loaded model.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.

         Returns:
-        None
+            None
         """

     def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions before LoRA weights are loaded.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        model (object): The loaded model.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.

         Returns:
-        None
+            None
         """

     def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions after LoRA weights are loaded.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        model (object): The loaded model.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.

         Returns:
-        None
+            None
         """

     def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
         """
         Returns a custom class for the trainer.

-        Parameters:
-        cfg (dict): The global axolotl configuration.
+        Args:
+            cfg (dict): The global axolotl configuration.

         Returns:
-        class: The class for the trainer.
+            class: The class for the trainer.
+        """
+
+    def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
+        """
+        Performs actions after the trainer is created.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            None
         """

     def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
         """
         Creates and returns an optimizer for training.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        trainer (object): The trainer object for training.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.

         Returns:
-        object: The created optimizer.
+            object: The created optimizer.
         """

     def create_lr_scheduler(
@@ -152,26 +180,26 @@
         """
         Creates and returns a learning rate scheduler.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        trainer (object): The trainer object for training.
-        optimizer (object): The optimizer for training.
-        num_training_steps (int): Total number of training steps
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+            optimizer (object): The optimizer for training.
+            num_training_steps (int): Total number of training steps

         Returns:
-        object (LRScheduler): The created learning rate scheduler.
+            object (LRScheduler): The created learning rate scheduler.
         """

     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
         """
         setup callbacks before creating the trainer.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        model (object): The loaded model.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.

         Returns:
-        List[callable]: A list of callback functions to be added to the TrainingArgs
+            List[callable]: A list of callback functions to be added to the TrainingArgs
         """
         return []
@@ -182,12 +210,12 @@
         Adds callbacks to the trainer after creating the trainer.
         This is useful for callbacks that require access to the model or trainer.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
-        trainer (object): The trainer object for training.
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.

         Returns:
-        List[callable]: A list of callback functions to be added
+            List[callable]: A list of callback functions to be added
         """
         return []
@@ -195,23 +223,23 @@
         """
         Performs actions after training is complete.

-        Parameters:
-        cfg (dict): The axolotl configuration
-        model (object): The loaded model.
+        Args:
+            cfg (dict): The axolotl configuration
+            model (object): The loaded model.

         Returns:
-        None
+            None
         """

     def post_train_unload(self, cfg):  # pylint: disable=unused-argument
         """
         Performs actions after training is complete and the model is unloaded.

-        Parameters:
-        cfg (dict): The configuration for the plugin.
+        Args:
+            cfg (dict): The configuration for the plugin.

         Returns:
-        None
+            None
         """
@@ -338,6 +366,27 @@ class PluginManager:
                 input_args.append(input_args_from_plugin)
         return input_args

+    def load_datasets(self, cfg, preprocess: bool = False):
+        """
+        Calls the load_datasets method of each registered plugin.
+
+        Args:
+            cfg: The configuration for the plugins.
+            preprocess: Whether this is the dataset preprocessing step.
+
+        Returns:
+            dataset_meta: The dataset metadata from the single plugin that loaded datasets, or None.
+        """
+        return_ds_meta = None
+        for plugin in self.plugins.values():
+            dataset_meta = plugin.load_datasets(cfg, preprocess)
+            if dataset_meta is not None:
+                if return_ds_meta is None:
+                    return_ds_meta = dataset_meta
+                else:
+                    raise RuntimeError("Multiple plugins loaded datasets")
+        return return_ds_meta
+
     def pre_model_load(self, cfg):
         """
         Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@
                 return trainer_cls
         return None

+    def post_trainer_create(self, cfg, trainer):
+        """
+        Calls the post_trainer_create method of all registered plugins.
+
+        Args:
+            cfg (dict): The configuration for the plugins.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins.values():
+            plugin.post_trainer_create(cfg, trainer)
+
     def create_optimizer(self, trainer):
         """
         Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
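For orientation, here is a minimal sketch of a downstream plugin using the two new hooks. Everything in it is illustrative rather than part of this diff: the plugin class, the hypothetical custom_dataset_path config key, and the assumption that TrainDatasetMeta (the dataclass the stock load_datasets() returns) is importable from axolotl.common.datasets with defaulted eval_dataset/total_num_steps fields.

from datasets import Dataset

from axolotl.common.datasets import TrainDatasetMeta  # assumed location
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault


class CustomDataPlugin(BasePlugin):
    """Hypothetical plugin that supplies its own training data."""

    def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
        # Returning None hands control back to the default
        # load_datasets / load_preference_datasets path in the CLI.
        if not cfg.get("custom_dataset_path"):
            return None
        train_dataset = Dataset.from_json(cfg["custom_dataset_path"])
        return TrainDatasetMeta(train_dataset=train_dataset)

    def post_trainer_create(self, cfg, trainer):
        # The trainer exists at this point, so it can be inspected or
        # adjusted before training starts.
        trainer.args.logging_first_step = True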
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 7ffd3f883..e58eddbff 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -513,6 +513,9 @@ def train(
         processor,
     ) = setup_model_and_trainer(cfg, dataset_meta)

+    plugin_manager = PluginManager.get_instance()
+    plugin_manager.post_trainer_create(cfg, trainer)
+
     # Handle untrained tokens if configured
     safe_serialization = cfg.save_safetensors is True
     train_dataset = dataset_meta.train_dataset
@@ -535,7 +538,6 @@
     if not cfg.use_ray:
         cleanup_distributed()

-    plugin_manager = PluginManager.get_instance()
     plugin_manager.post_train(cfg, model)

     return model, tokenizer, trainer
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 9b12e6d4e..45d7200fb 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
         except FileNotFoundError:
             pass

+    def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
+        with open(
+            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
+        ) as f:
+            f.write("post_trainer_create\n")
+
     def pre_model_load(self, cfg):  # pylint: disable=unused-argument
         with open(
             self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
         ) as f:
             file_contents = f.readlines()
         file_contents = "\n".join(file_contents)
+        assert "post_trainer_create" in file_contents
         assert "pre_model_load" in file_contents
         assert "post_model_build" in file_contents
         assert "pre_lora_load" in file_contents
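A note on the dispatch semantics of PluginManager.load_datasets above: dataset loading is exclusive, so at most one registered plugin may return metadata, and a second non-None result raises RuntimeError. A quick sketch of that failure mode, with hypothetical plugin import paths and assuming register() accepts a plugin path string as it does for existing integrations:

from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault

plugin_manager = PluginManager.get_instance()
plugin_manager.register("my_pkg.plugin_a.PluginA")  # hypothetical; implements load_datasets
plugin_manager.register("my_pkg.plugin_b.PluginB")  # hypothetical; also implements it

try:
    # If both plugins return dataset metadata, the manager refuses to guess.
    plugin_manager.load_datasets(DictDefault({}), preprocess=False)
except RuntimeError as err:
    print(err)  # "Multiple plugins loaded datasets"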