Add Post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539)
* Update train.py add post_model_load and post_lora_load model calss. * Update train.py add post_train and post_train_unload function calls * Update train.py * Update base.py * Update train.py * chore: lint * clarify plugin hooks * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update models.py * Update models.py * remove extra call to post_model_load * chore: lint * add test for hooks and gc trainer * disable duplicated code check for test * fix the path and add better handling --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5d182a1056
commit
170cdb5be9
@@ -1,5 +1,6 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
del model, tokenizer, trainer
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ class BasePlugin:
|
|||||||
Methods:
|
Methods:
|
||||||
register(cfg): Registers the plugin with the given configuration.
|
register(cfg): Registers the plugin with the given configuration.
|
||||||
pre_model_load(cfg): Performs actions before the model is loaded.
|
pre_model_load(cfg): Performs actions before the model is loaded.
|
||||||
post_model_load(cfg, model): Performs actions after the model is loaded.
|
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
|
||||||
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||||
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||||
|
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
|
||||||
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||||
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
|
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
|
||||||
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
|
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
|
||||||
@@ -77,6 +78,14 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Performs actions after the model is built/loaded, but before any adapters are applied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
"""
|
||||||
|
|
||||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Performs actions after the model is loaded.
|
Performs actions after the model is loaded.
|
||||||
@@ -329,9 +338,22 @@ class PluginManager:
|
|||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.pre_model_load(cfg)
|
plugin.pre_model_load(cfg)
|
||||||
|
|
||||||
|
def post_model_build(self, cfg, model):
|
||||||
|
"""
|
||||||
|
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
|
||||||
|
but before any adapters have been applied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
model (object): The loaded model.
|
||||||
|
"""
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
def post_model_load(self, cfg, model):
|
def post_model_load(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Calls the post_model_load method of all registered plugins.
|
Calls the post_model_load method of all registered plugins after the model has been loaded
|
||||||
|
inclusive of any adapters
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugins.
|
cfg (dict): The configuration for the plugins.
|
||||||
@@ -458,6 +480,20 @@ class PluginManager:
|
|||||||
callbacks.extend(plugin_callbacks)
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
def post_train(self, cfg, model):
|
||||||
|
"""
|
||||||
|
Calls the post_train method of all registered plugins.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
model (object): The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
plugin.post_train(cfg, model)
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
"""
|
"""
|
||||||
Calls the post_train_unload method of all registered plugins.
|
Calls the post_train_unload method of all registered plugins.
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
|
|||||||
from axolotl.core.trainers.mixins.sequence_parallel import (
|
from axolotl.core.trainers.mixins.sequence_parallel import (
|
||||||
SequenceParallelContextManager,
|
SequenceParallelContextManager,
|
||||||
)
|
)
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
@@ -533,4 +534,7 @@ def train(
|
|||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_manager.post_train(cfg, model)
|
||||||
|
|
||||||
return model, tokenizer, trainer
|
return model, tokenizer, trainer
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from transformers.integrations.deepspeed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -74,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
|||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||||
|
|
||||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||||
"mllama": MllamaForConditionalGeneration,
|
"mllama": MllamaForConditionalGeneration,
|
||||||
@@ -571,10 +573,8 @@ class ModelLoader:
|
|||||||
patch_gemma3conditionalgeneration_forward()
|
patch_gemma3conditionalgeneration_forward()
|
||||||
|
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
from axolotl.integrations.base import PluginManager
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
|
||||||
|
|
||||||
# monkey patch to allow additional Accelerator init kwargs
|
# monkey patch to allow additional Accelerator init kwargs
|
||||||
if self.cfg.fp8:
|
if self.cfg.fp8:
|
||||||
@@ -1252,6 +1252,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
skip_move_to_device = self.build_model(qlora_fsdp)
|
skip_move_to_device = self.build_model(qlora_fsdp)
|
||||||
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
except Exception as err: # pylint: disable=broad-exception-caught
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
raise err
|
raise err
|
||||||
@@ -1331,6 +1332,8 @@ class ModelLoader:
|
|||||||
before_kbit_train_or_finetune=False,
|
before_kbit_train_or_finetune=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# load lora or adapter
|
# load lora or adapter
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
@@ -1392,7 +1395,7 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
|
||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
|
||||||
@@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
model, lora_config = load_lora(model, cfg, inference=inference)
|
||||||
|
PLUGIN_MANAGER.post_lora_load(cfg, model)
|
||||||
|
return model, lora_config
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
return load_llama_adapter(model, cfg)
|
model, lora_config = load_llama_adapter(model, cfg)
|
||||||
|
PLUGIN_MANAGER.post_lora_load(cfg, model)
|
||||||
|
return model, lora_config
|
||||||
|
|
||||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||||
|
|
||||||
|
|||||||
184
tests/e2e/integrations/test_hooks.py
Normal file
184
tests/e2e/integrations/test_hooks.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""
|
||||||
|
e2e tests to make sure all the hooks are fired on the plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import check_model_output_exists
|
||||||
|
|
||||||
|
|
||||||
|
class LogHooksPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
fixture to capture in a log file each hook that was fired
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_dir = Path("/tmp/axolotl-log-hooks")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
os.remove(self.base_dir.joinpath("plugin_hooks.log"))
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("pre_model_load\n")
|
||||||
|
|
||||||
|
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("post_model_build\n")
|
||||||
|
|
||||||
|
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("pre_lora_load\n")
|
||||||
|
|
||||||
|
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("post_lora_load\n")
|
||||||
|
|
||||||
|
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("post_model_load\n")
|
||||||
|
|
||||||
|
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("create_optimizer\n")
|
||||||
|
|
||||||
|
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("get_trainer_cls\n")
|
||||||
|
|
||||||
|
def create_lr_scheduler(
|
||||||
|
self, cfg, trainer, optimizer
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("create_lr_scheduler\n")
|
||||||
|
|
||||||
|
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("add_callbacks_pre_trainer\n")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(
|
||||||
|
self, cfg, trainer
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("add_callbacks_post_trainer\n")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("post_train\n")
|
||||||
|
|
||||||
|
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||||
|
with open(
|
||||||
|
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
f.write("post_train_unload\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginHooks:
|
||||||
|
"""
|
||||||
|
e2e tests to make sure all the hooks are fired during the training
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_plugin_hooks(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"plugins": [
|
||||||
|
"tests.e2e.integrations.test_hooks.LogHooksPlugin",
|
||||||
|
],
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"flash_attention": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
with open(
|
||||||
|
"/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
|
||||||
|
) as f:
|
||||||
|
file_contents = f.readlines()
|
||||||
|
file_contents = "\n".join(file_contents)
|
||||||
|
assert "pre_model_load" in file_contents
|
||||||
|
assert "post_model_build" in file_contents
|
||||||
|
assert "pre_lora_load" in file_contents
|
||||||
|
assert "post_lora_load" in file_contents
|
||||||
|
assert "post_model_load" in file_contents
|
||||||
|
# assert "create_optimizer" in file_contents # not implemented yet
|
||||||
|
assert "get_trainer_cls" in file_contents
|
||||||
|
# assert "create_lr_scheduler" in file_contents # not implemented yet
|
||||||
|
assert "add_callbacks_pre_trainer" in file_contents
|
||||||
|
assert "add_callbacks_post_trainer" in file_contents
|
||||||
|
assert "post_train" in file_contents
|
||||||
|
# assert "post_train_unload" in file_contents # not called from test train call
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
@@ -48,13 +48,13 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user