From 170cdb5be94809b871d2f5aea2352b7d94966308 Mon Sep 17 00:00:00 2001 From: divyanshuaggarwal Date: Mon, 28 Apr 2025 19:40:28 +0530 Subject: [PATCH] Add Post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539) * Update train.py add post_model_load and post_lora_load model calss. * Update train.py add post_train and post_train_unload function calls * Update train.py * Update base.py * Update train.py * chore: lint * clarify plugin hooks * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders * Update models.py * Update models.py * remove extra call to post_model_load * chore: lint * add test for hooks and gc trainer * disable duplicated code check for test * fix the path and add better handling --------- Co-authored-by: Wing Lian Co-authored-by: Dan Saunders --- src/axolotl/cli/train.py | 4 + src/axolotl/integrations/base.py | 40 +++++- src/axolotl/train.py | 4 + src/axolotl/utils/models.py | 19 ++- tests/e2e/integrations/test_hooks.py | 184 +++++++++++++++++++++++++++ tests/e2e/test_lora_llama.py | 4 +- 6 files changed, 245 insertions(+), 10 deletions(-) create mode 100644 tests/e2e/integrations/test_hooks.py diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index e225141b6..4f258313d 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,5 +1,6 @@ """CLI to run training on a model.""" +import gc import logging import os from pathlib import Path @@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + del model, tokenizer, trainer + gc.collect() + plugin_manager = PluginManager.get_instance() plugin_manager.post_train_unload(cfg) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 11015e31a..cb65f96dd 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -36,9 +36,10 @@ class BasePlugin: Methods: register(cfg): Registers the plugin with the given configuration. pre_model_load(cfg): Performs actions before the model is loaded. - post_model_load(cfg, model): Performs actions after the model is loaded. + post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. create_optimizer(cfg, trainer): Creates and returns an optimizer for training. create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler. add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. @@ -77,6 +78,14 @@ class BasePlugin: None """ + def post_model_build(self, cfg, model): # pylint: disable=unused-argument + """ + Performs actions after the model is built/loaded, but before any adapters are applied. + + Args: + cfg (dict): The configuration for the plugin. + """ + def post_model_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after the model is loaded. 
@@ -329,9 +338,22 @@ class PluginManager: for plugin in self.plugins.values(): plugin.pre_model_load(cfg) + def post_model_build(self, cfg, model): + """ + Calls the post_model_build method of all registered plugins after the model has been built/loaded, + but before any adapters have been applied. + + Args: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + """ + for plugin in self.plugins.values(): + plugin.post_model_build(cfg, model) + def post_model_load(self, cfg, model): """ - Calls the post_model_load method of all registered plugins. + Calls the post_model_load method of all registered plugins after the model has been loaded + inclusive of any adapters Parameters: cfg (dict): The configuration for the plugins. @@ -458,6 +480,20 @@ class PluginManager: callbacks.extend(plugin_callbacks) return callbacks + def post_train(self, cfg, model): + """ + Calls the post_train method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins.values(): + plugin.post_train(cfg, model) + def post_train_unload(self, cfg): """ Calls the post_train_unload method of all registered plugins. diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d116ea4fd..7896239de 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -29,6 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil from axolotl.core.trainers.mixins.sequence_parallel import ( SequenceParallelContextManager, ) +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed @@ -533,4 +534,7 @@ def train( if not cfg.use_ray: cleanup_distributed() + plugin_manager = PluginManager.get_instance() + plugin_manager.post_train(cfg, model) + return model, tokenizer, trainer diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d7105daba..ab4cc19bb 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -53,6 +53,7 @@ from transformers.integrations.deepspeed import ( ) from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.base import PluginManager from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -74,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() MULTIMODAL_AUTO_MODEL_MAPPING = { "mllama": MllamaForConditionalGeneration, @@ -571,10 +573,8 @@ class ModelLoader: patch_gemma3conditionalgeneration_forward() # load any patches from plugins - from axolotl.integrations.base import PluginManager - plugin_manager = PluginManager.get_instance() - plugin_manager.pre_model_load(self.cfg) + PLUGIN_MANAGER.pre_model_load(self.cfg) # monkey patch to allow additional Accelerator init kwargs if self.cfg.fp8: @@ -1252,6 +1252,7 @@ class ModelLoader: try: skip_move_to_device = self.build_model(qlora_fsdp) + PLUGIN_MANAGER.post_model_build(self.cfg, self.model) except Exception as err: # pylint: disable=broad-exception-caught LOG.exception(err) raise err @@ -1331,6 +1332,8 @@ class ModelLoader: before_kbit_train_or_finetune=False, ) + 
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model) + # --------------------------------------------------------- # load lora or adapter # --------------------------------------------------------- @@ -1392,7 +1395,7 @@ class ModelLoader: gc.collect() torch.cuda.empty_cache() - # TODO resume_from_checkpoint handling + PLUGIN_MANAGER.post_model_load(self.cfg, self.model) return self.model, lora_config @@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False): if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() if adapter in ["lora", "qlora"]: - return load_lora(model, cfg, inference=inference) + model, lora_config = load_lora(model, cfg, inference=inference) + PLUGIN_MANAGER.post_lora_load(cfg, model) + return model, lora_config if adapter == "llama-adapter": - return load_llama_adapter(model, cfg) + model, lora_config = load_llama_adapter(model, cfg) + PLUGIN_MANAGER.post_lora_load(cfg, model) + return model, lora_config raise NotImplementedError(f"{adapter} peft adapter not available") diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py new file mode 100644 index 000000000..e51334dfe --- /dev/null +++ b/tests/e2e/integrations/test_hooks.py @@ -0,0 +1,184 @@ +""" +e2e tests to make sure all the hooks are fired on the plugin +""" + +import os +from pathlib import Path + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.integrations.base import BasePlugin +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config +from axolotl.utils.dict import DictDefault + +from ..utils import check_model_output_exists + + +class LogHooksPlugin(BasePlugin): + """ + fixture to capture in a log file each hook that was fired + """ + + base_dir = Path("/tmp/axolotl-log-hooks") + + def __init__(self): + self.base_dir.mkdir(parents=True, exist_ok=True) + try: + os.remove(self.base_dir.joinpath("plugin_hooks.log")) + except FileNotFoundError: + pass + + def pre_model_load(self, cfg): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("pre_model_load\n") + + def post_model_build(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("post_model_build\n") + + def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("pre_lora_load\n") + + def post_lora_load(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("post_lora_load\n") + + def post_model_load(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("post_model_load\n") + + def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("create_optimizer\n") + + def get_trainer_cls(self, cfg): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("get_trainer_cls\n") + + def create_lr_scheduler( + self, cfg, trainer, optimizer + ): # pylint: disable=unused-argument + with open( 
+ self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("create_lr_scheduler\n") + + def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("add_callbacks_pre_trainer\n") + return [] + + def add_callbacks_post_trainer( + self, cfg, trainer + ): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("add_callbacks_post_trainer\n") + return [] + + def post_train(self, cfg, model): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("post_train\n") + + def post_train_unload(self, cfg): # pylint: disable=unused-argument + with open( + self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" + ) as f: + f.write("post_train_unload\n") + + +class TestPluginHooks: + """ + e2e tests to make sure all the hooks are fired during the training + """ + + def test_plugin_hooks(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "plugins": [ + "tests.e2e.integrations.test_hooks.LogHooksPlugin", + ], + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "max_steps": 5, + "flash_attention": True, + "bf16": "auto", + } + ) + + cfg = validate_config(cfg) + prepare_plugins(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + with open( + "/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8" + ) as f: + file_contents = f.readlines() + file_contents = "\n".join(file_contents) + assert "pre_model_load" in file_contents + assert "post_model_build" in file_contents + assert "pre_lora_load" in file_contents + assert "post_lora_load" in file_contents + assert "post_model_load" in file_contents + # assert "create_optimizer" in file_contents # not implemented yet + assert "get_trainer_cls" in file_contents + # assert "create_lr_scheduler" in file_contents # not implemented yet + assert "add_callbacks_pre_trainer" in file_contents + assert "add_callbacks_post_trainer" in file_contents + assert "post_train" in file_contents + # assert "post_train_unload" in file_contents # not called from test train call + + try: + os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log") + except FileNotFoundError: + pass diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index e5a734b33..b02fe3d44 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -48,13 +48,13 @@ class TestLoraLlama(unittest.TestCase): }, ], "num_epochs": 1, - "micro_batch_size": 8, + "micro_batch_size": 2, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "max_steps": 20, + 
"max_steps": 5, } )