allow plugins to return their own dataset (#2617) [skip ci]

* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback

* remove annotation causing circular import
This commit is contained in:
Wing Lian
2025-05-06 20:05:51 -04:00
committed by GitHub
parent 0b140fef83
commit cd84325253
5 changed files with 129 additions and 50 deletions

View File

@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching(): with disable_datasets_caching():
if cfg.rl: plugin_manager = PluginManager.get_instance()
if plugin_manager.load_datasets(cfg, preprocess=True):
pass
elif cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args) load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
load_datasets(cfg=cfg, cli_args=cli_args) load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0: if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token() check_user_token()
if cfg.rl: plugin_manager = PluginManager.get_instance()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
else: if not dataset_meta:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -26,6 +26,8 @@ from typing import OrderedDict
import torch import torch
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from axolotl.utils.dict import DictDefault
class BasePlugin: class BasePlugin:
""" """
@@ -36,11 +38,13 @@ class BasePlugin:
Methods: Methods:
register(cfg): Registers the plugin with the given configuration. register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded. pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training. create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler. create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ class BasePlugin:
None None
""" """
def get_input_args(self): def get_input_args(self) -> str | None:
""" """
Returns a pydantic model for the plugin's input arguments. Returns a pydantic model for the plugin's input arguments.
""" """
def load_datasets(
    self, cfg: DictDefault, preprocess: bool = False
):  # pylint: disable=unused-argument
    """
    Loads and preprocesses the dataset for training.

    The base implementation has no body and therefore returns None. Plugins
    that supply their own datasets override this hook.

    Args:
        cfg: The configuration for the plugin.
        preprocess: Whether this is the preprocess step of the datasets.

    Returns:
        dataset_meta: The metadata for the training dataset, or None when the
            plugin does not provide its own datasets.
    """
def pre_model_load(self, cfg): # pylint: disable=unused-argument def pre_model_load(self, cfg): # pylint: disable=unused-argument
""" """
Performs actions before the model is loaded. Performs actions before the model is loaded.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
Returns: Returns:
None None
""" """
def post_model_build(self, cfg, model): # pylint: disable=unused-argument def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ class BasePlugin:
""" """
Performs actions after the model is loaded. Performs actions after the model is loaded.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
model (object): The loaded model. model (object): The loaded model.
Returns: Returns:
None None
""" """
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
""" """
Performs actions before LoRA weights are loaded. Performs actions before LoRA weights are loaded.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
model (object): The loaded model. model (object): The loaded model.
Returns: Returns:
None None
""" """
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
""" """
Performs actions after LoRA weights are loaded. Performs actions after LoRA weights are loaded.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
model (object): The loaded model. model (object): The loaded model.
Returns: Returns:
None None
""" """
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument): def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
""" """
Returns a custom class for the trainer. Returns a custom class for the trainer.
Parameters: Args:
cfg (dict): The global axolotl configuration. cfg (dict): The global axolotl configuration.
Returns: Returns:
class: The class for the trainer. class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
""" """
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
""" """
Creates and returns an optimizer for training. Creates and returns an optimizer for training.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
Returns: Returns:
object: The created optimizer. object: The created optimizer.
""" """
def create_lr_scheduler( def create_lr_scheduler(
@@ -152,26 +180,26 @@ class BasePlugin:
""" """
Creates and returns a learning rate scheduler. Creates and returns a learning rate scheduler.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
optimizer (object): The optimizer for training. optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps num_training_steps (int): Total number of training steps
Returns: Returns:
object (LRScheduler): The created learning rate scheduler. object (LRScheduler): The created learning rate scheduler.
""" """
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
""" """
setup callbacks before creating the trainer. setup callbacks before creating the trainer.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
model (object): The loaded model. model (object): The loaded model.
Returns: Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
return [] return []
@@ -182,12 +210,12 @@ class BasePlugin:
Adds callbacks to the trainer after creating the trainer. Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer. This is useful for callbacks that require access to the model or trainer.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
Returns: Returns:
List[callable]: A list of callback functions to be added List[callable]: A list of callback functions to be added
""" """
return [] return []
@@ -195,23 +223,23 @@ class BasePlugin:
""" """
Performs actions after training is complete. Performs actions after training is complete.
Parameters: Args:
cfg (dict): The axolotl configuration cfg (dict): The axolotl configuration
model (object): The loaded model. model (object): The loaded model.
Returns: Returns:
None None
""" """
def post_train_unload(self, cfg): # pylint: disable=unused-argument def post_train_unload(self, cfg): # pylint: disable=unused-argument
""" """
Performs actions after training is complete and the model is unloaded. Performs actions after training is complete and the model is unloaded.
Parameters: Args:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
Returns: Returns:
None None
""" """
@@ -338,6 +366,27 @@ class PluginManager:
input_args.append(input_args_from_plugin) input_args.append(input_args_from_plugin)
return input_args return input_args
def load_datasets(self, cfg, preprocess: bool = False):
    """
    Calls the load_datasets method of each registered plugin.

    At most one plugin may supply datasets. Every registered plugin is still
    queried, so a conflict between two plugins is reported rather than the
    first result being used silently.

    Args:
        cfg: The configuration for the plugins.
        preprocess: Whether this is the preprocess step of the datasets.

    Returns:
        dataset_meta: The dataset metadata returned by the single plugin that
            provided one, or None if no plugin returned any.

    Raises:
        RuntimeError: If more than one plugin returns dataset metadata.
    """
    return_ds_meta = None
    provider = None
    # NOTE(review): keys of self.plugins are assumed to identify the plugin
    # (e.g. its registered name); they are only used in the error message.
    for plugin_key, plugin in self.plugins.items():
        dataset_meta = plugin.load_datasets(cfg, preprocess)
        if dataset_meta is None:
            continue
        if return_ds_meta is not None:
            # Name both offenders so the user knows which plugins to reconcile.
            raise RuntimeError(
                f"Multiple plugins loaded datasets: {provider!r} and {plugin_key!r}"
            )
        return_ds_meta = dataset_meta
        provider = plugin_key
    return return_ds_meta
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
""" """
Calls the pre_model_load method of all registered plugins. Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ class PluginManager:
return trainer_cls return trainer_cls
return None return None
def post_trainer_create(self, cfg, trainer):
    """
    Calls the post_trainer_create method of all registered plugins.

    Args:
        cfg (dict): The configuration for the plugins.
        trainer (object): The trainer object for training.

    Returns:
        None
    """
    # Whatever the individual plugin hooks return is discarded.
    for plugin in self.plugins.values():
        plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer): def create_optimizer(self, trainer):
""" """
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -528,6 +528,9 @@ def train(
processor, processor,
) = setup_model_and_trainer(cfg, dataset_meta) ) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured # Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
@@ -550,7 +553,6 @@ def train(
if not cfg.use_ray: if not cfg.use_ray:
cleanup_distributed() cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model) plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer return model, tokenizer, trainer

View File

@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
except FileNotFoundError: except FileNotFoundError:
pass pass
def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
    """Append a "post_trainer_create" line to plugin_hooks.log under base_dir."""
    hook_log = self.base_dir.joinpath("plugin_hooks.log")
    with open(hook_log, "a", encoding="utf-8") as handle:
        handle.write("post_trainer_create\n")
def pre_model_load(self, cfg): # pylint: disable=unused-argument def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open( with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
) as f: ) as f:
file_contents = f.readlines() file_contents = f.readlines()
file_contents = "\n".join(file_contents) file_contents = "\n".join(file_contents)
assert "post_trainer_create" in file_contents
assert "pre_model_load" in file_contents assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents assert "pre_lora_load" in file_contents