allow plugins to return their own dataset (#2617) [skip ci]

* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback:

* remove annotation causing circular import
This commit is contained in:
Wing Lian
2025-05-06 20:05:51 -04:00
committed by GitHub
parent 0b140fef83
commit cd84325253
5 changed files with 129 additions and 50 deletions

View File

@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if cfg.rl:
plugin_manager = PluginManager.get_instance()
if plugin_manager.load_datasets(cfg, preprocess=True):
pass
elif cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -26,6 +26,8 @@ from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.utils.dict import DictDefault
class BasePlugin:
"""
@@ -36,11 +38,13 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ class BasePlugin:
None
"""
def get_input_args(self):
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
preprocess: Whether this is the preprocess step of the datasets.
Returns:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ class BasePlugin:
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
object: The created optimizer.
"""
def create_lr_scheduler(
@@ -152,26 +180,26 @@ class BasePlugin:
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
@@ -182,12 +210,12 @@ class BasePlugin:
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
List[callable]: A list of callback functions to be added
"""
return []
@@ -195,23 +223,23 @@ class BasePlugin:
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
@@ -338,6 +366,27 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
dataset_meta = plugin.load_datasets(cfg, preprocess)
if dataset_meta is not None:
if return_ds_meta is None:
return_ds_meta = dataset_meta
else:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ class PluginManager:
return trainer_cls
return None
def post_trainer_create(self, cfg, trainer):
"""
Calls the post_trainer_create method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -528,6 +528,9 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
@@ -550,7 +553,6 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
except FileNotFoundError:
pass
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_trainer_create\n")
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "post_trainer_create" in file_contents
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents