Add Post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539)

* Update train.py

add post_model_load and post_lora_load model calss.

* Update train.py

add post_train and post_train_unload function calls

* Update train.py

* Update base.py

* Update train.py

* chore: lint

* clarify plugin hooks

* Update src/axolotl/integrations/base.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/utils/models.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/utils/models.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/integrations/base.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update models.py

* Update models.py

* remove extra call to post_model_load

* chore: lint

* add test for hooks and gc trainer

* disable duplicated code check for test

* fix the path and add better handling

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
divyanshuaggarwal
2025-04-28 19:40:28 +05:30
committed by GitHub
parent 5d182a1056
commit 170cdb5be9
6 changed files with 245 additions and 10 deletions

View File

@@ -1,5 +1,6 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)

View File

@@ -36,9 +36,10 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -77,6 +78,14 @@ class BasePlugin:
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg (dict): The configuration for the plugin.
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -329,9 +338,22 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
Args:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
Parameters:
cfg (dict): The configuration for the plugins.
@@ -458,6 +480,20 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.

View File

@@ -29,6 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -533,4 +534,7 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -53,6 +53,7 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -74,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
@@ -571,10 +573,8 @@ class ModelLoader:
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
PLUGIN_MANAGER.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
if self.cfg.fp8:
@@ -1252,6 +1252,7 @@ class ModelLoader:
try:
skip_move_to_device = self.build_model(qlora_fsdp)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.exception(err)
raise err
@@ -1331,6 +1332,8 @@ class ModelLoader:
before_kbit_train_or_finetune=False,
)
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
# ---------------------------------------------------------
# load lora or adapter
# ---------------------------------------------------------
@@ -1392,7 +1395,7 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
return self.model, lora_config
@@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
model, lora_config = load_lora(model, cfg, inference=inference)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
model, lora_config = load_llama_adapter(model, cfg)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
raise NotImplementedError(f"{adapter} peft adapter not available")

View File

@@ -0,0 +1,184 @@
"""
e2e tests to make sure all the hooks are fired on the plugin
"""
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
class LogHooksPlugin(BasePlugin):
"""
fixture to capture in a log file each hook that was fired
"""
base_dir = Path("/tmp/axolotl-log-hooks")
def __init__(self):
self.base_dir.mkdir(parents=True, exist_ok=True)
try:
os.remove(self.base_dir.joinpath("plugin_hooks.log"))
except FileNotFoundError:
pass
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_model_load\n")
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_build\n")
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_lora_load\n")
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_lora_load\n")
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_load\n")
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_optimizer\n")
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_lr_scheduler\n")
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_pre_trainer\n")
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_post_trainer\n")
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train\n")
def post_train_unload(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train_unload\n")
class TestPluginHooks:
"""
e2e tests to make sure all the hooks are fired during the training
"""
def test_plugin_hooks(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"tests.e2e.integrations.test_hooks.LogHooksPlugin",
],
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(
"/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents
assert "post_lora_load" in file_contents
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents
# assert "post_train_unload" in file_contents # not called from test train call
try:
os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
except FileNotFoundError:
pass

View File

@@ -48,13 +48,13 @@ class TestLoraLlama(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"max_steps": 5,
}
)