Add post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539)

* Update train.py add post_model_load and post_lora_load model calls
* Update train.py add post_train and post_train_unload function calls
* Update train.py
* Update base.py
* Update train.py
* chore: lint
* clarify plugin hooks
* Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* Update models.py
* Update models.py
* remove extra call to post_model_load
* chore: lint
* add test for hooks and gc trainer
* disable duplicated code check for test
* fix the path and add better handling

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
parent 5d182a1056
commit 170cdb5be9
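For context on how downstream code would consume these hooks, here is a minimal sketch of a plugin that overrides the lifecycle methods this PR wires up (the module path and class name are illustrative, not part of this change; the hook signatures mirror BasePlugin in the diff below):

# illustrative sketch only -- LifecycleLoggerPlugin is a hypothetical plugin, not shipped in this PR
from axolotl.integrations.base import BasePlugin


class LifecycleLoggerPlugin(BasePlugin):
    """Prints a line for each lifecycle stage wired up by this PR."""

    def post_model_build(self, cfg, model):
        # model has been built/loaded, but no adapters have been applied yet
        print(f"built {type(model).__name__}")

    def post_lora_load(self, cfg, model):
        # LoRA/adapter weights were just loaded onto the model
        print("adapters attached")

    def post_model_load(self, cfg, model):
        # model is fully loaded, inclusive of any adapters
        print("model ready")

    def post_train(self, cfg, model):
        # train() has finished; the model object still exists here
        print("training finished")

    def post_train_unload(self, cfg):
        # called from the CLI after model, tokenizer, and trainer are deleted and gc'd
        print("resources released")

A plugin like this is enabled the same way the new e2e test enables LogHooksPlugin: by listing its dotted import path under the plugins key of the training config.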
@@ -1,5 +1,6 @@
"""CLI to run training on a model."""

import gc
import logging
import os
from pathlib import Path

@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

    del model, tokenizer, trainer

    gc.collect()

    plugin_manager = PluginManager.get_instance()
    plugin_manager.post_train_unload(cfg)
@@ -36,9 +36,10 @@ class BasePlugin:
    Methods:
        register(cfg): Registers the plugin with the given configuration.
        pre_model_load(cfg): Performs actions before the model is loaded.
        post_model_load(cfg, model): Performs actions after the model is loaded.
        post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
        pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
        post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
        post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
        create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
        create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
        add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.

@@ -77,6 +78,14 @@ class BasePlugin:
            None
        """

    def post_model_build(self, cfg, model):  # pylint: disable=unused-argument
        """
        Performs actions after the model is built/loaded, but before any adapters are applied.

        Args:
            cfg (dict): The configuration for the plugin.
        """

    def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
        """
        Performs actions after the model is loaded.

@@ -329,9 +338,22 @@ class PluginManager:
        for plugin in self.plugins.values():
            plugin.pre_model_load(cfg)

    def post_model_build(self, cfg, model):
        """
        Calls the post_model_build method of all registered plugins after the model has been built/loaded,
        but before any adapters have been applied.

        Args:
            cfg (dict): The configuration for the plugins.
            model (object): The loaded model.
        """
        for plugin in self.plugins.values():
            plugin.post_model_build(cfg, model)

    def post_model_load(self, cfg, model):
        """
        Calls the post_model_load method of all registered plugins.
        Calls the post_model_load method of all registered plugins after the model has been loaded
        inclusive of any adapters

        Parameters:
            cfg (dict): The configuration for the plugins.

@@ -458,6 +480,20 @@ class PluginManager:
            callbacks.extend(plugin_callbacks)
        return callbacks

    def post_train(self, cfg, model):
        """
        Calls the post_train method of all registered plugins.

        Parameters:
            cfg (dict): The configuration for the plugins.
            model (object): The loaded model.

        Returns:
            None
        """
        for plugin in self.plugins.values():
            plugin.post_train(cfg, model)

    def post_train_unload(self, cfg):
        """
        Calls the post_train_unload method of all registered plugins.
@@ -29,6 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
from axolotl.core.trainers.mixins.sequence_parallel import (
    SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed

@@ -533,4 +534,7 @@ def train(
    if not cfg.use_ray:
        cleanup_distributed()

    plugin_manager = PluginManager.get_instance()
    plugin_manager.post_train(cfg, model)

    return model, tokenizer, trainer
@@ -53,6 +53,7 @@ from transformers.integrations.deepspeed import (
)

from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,

@@ -74,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

MULTIMODAL_AUTO_MODEL_MAPPING = {
    "mllama": MllamaForConditionalGeneration,

@@ -571,10 +573,8 @@ class ModelLoader:
            patch_gemma3conditionalgeneration_forward()

        # load any patches from plugins
        from axolotl.integrations.base import PluginManager

        plugin_manager = PluginManager.get_instance()
        plugin_manager.pre_model_load(self.cfg)
        PLUGIN_MANAGER.pre_model_load(self.cfg)

        # monkey patch to allow additional Accelerator init kwargs
        if self.cfg.fp8:

@@ -1252,6 +1252,7 @@ class ModelLoader:

        try:
            skip_move_to_device = self.build_model(qlora_fsdp)
            PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
        except Exception as err:  # pylint: disable=broad-exception-caught
            LOG.exception(err)
            raise err

@@ -1331,6 +1332,8 @@ class ModelLoader:
                before_kbit_train_or_finetune=False,
            )

        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)

        # ---------------------------------------------------------
        # load lora or adapter
        # ---------------------------------------------------------

@@ -1392,7 +1395,7 @@ class ModelLoader:
        gc.collect()
        torch.cuda.empty_cache()

        # TODO resume_from_checkpoint handling
        PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
        return self.model, lora_config

@@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if adapter in ["lora", "qlora"]:
        return load_lora(model, cfg, inference=inference)
        model, lora_config = load_lora(model, cfg, inference=inference)
        PLUGIN_MANAGER.post_lora_load(cfg, model)
        return model, lora_config
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)
        model, lora_config = load_llama_adapter(model, cfg)
        PLUGIN_MANAGER.post_lora_load(cfg, model)
        return model, lora_config

    raise NotImplementedError(f"{adapter} peft adapter not available")
tests/e2e/integrations/test_hooks.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""
e2e tests to make sure all the hooks are fired on the plugin
"""

import os
from pathlib import Path

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists


class LogHooksPlugin(BasePlugin):
    """
    fixture to capture in a log file each hook that was fired
    """

    base_dir = Path("/tmp/axolotl-log-hooks")

    def __init__(self):
        self.base_dir.mkdir(parents=True, exist_ok=True)
        try:
            os.remove(self.base_dir.joinpath("plugin_hooks.log"))
        except FileNotFoundError:
            pass

    def pre_model_load(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("pre_model_load\n")

    def post_model_build(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_model_build\n")

    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("pre_lora_load\n")

    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_lora_load\n")

    def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_model_load\n")

    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("create_optimizer\n")

    def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("get_trainer_cls\n")

    def create_lr_scheduler(
        self, cfg, trainer, optimizer
    ):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("create_lr_scheduler\n")

    def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("add_callbacks_pre_trainer\n")
        return []

    def add_callbacks_post_trainer(
        self, cfg, trainer
    ):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("add_callbacks_post_trainer\n")
        return []

    def post_train(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_train\n")

    def post_train_unload(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_train_unload\n")


class TestPluginHooks:
    """
    e2e tests to make sure all the hooks are fired during the training
    """

    def test_plugin_hooks(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "plugins": [
                    "tests.e2e.integrations.test_hooks.LogHooksPlugin",
                ],
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.02,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "flash_attention": True,
                "bf16": "auto",
            }
        )

        cfg = validate_config(cfg)
        prepare_plugins(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)

        with open(
            "/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
        ) as f:
            file_contents = f.readlines()
        file_contents = "\n".join(file_contents)
        assert "pre_model_load" in file_contents
        assert "post_model_build" in file_contents
        assert "pre_lora_load" in file_contents
        assert "post_lora_load" in file_contents
        assert "post_model_load" in file_contents
        # assert "create_optimizer" in file_contents  # not implemented yet
        assert "get_trainer_cls" in file_contents
        # assert "create_lr_scheduler" in file_contents  # not implemented yet
        assert "add_callbacks_pre_trainer" in file_contents
        assert "add_callbacks_post_trainer" in file_contents
        assert "post_train" in file_contents
        # assert "post_train_unload" in file_contents  # not called from test train call

        try:
            os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
        except FileNotFoundError:
            pass
@@ -48,13 +48,13 @@ class TestLoraLlama(unittest.TestCase):
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 8,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "max_steps": 5,
            }
        )
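Taken together, the call sites above imply roughly the following firing order over a single training run (a sketch inferred from this diff and the assertions in test_hooks.py; trainer-construction hooks such as get_trainer_cls and the add_callbacks_* hooks fire between model load and post_train and are omitted here):

# rough lifecycle order implied by this diff; illustrative, not an API guarantee
EXPECTED_HOOK_ORDER = [
    "pre_model_load",      # ModelLoader, before patches and model load
    "post_model_build",    # after build_model(), before adapters are applied
    "pre_lora_load",       # just before the adapter is loaded
    "post_lora_load",      # load_adapter(), after LoRA weights are attached
    "post_model_load",     # end of model loading, inclusive of any adapters
    "post_train",          # end of train()
    "post_train_unload",   # CLI do_train(), after model/tokenizer/trainer are gc'd
]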