Add Post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539)
* Update train.py add post_model_load and post_lora_load model calss. * Update train.py add post_train and post_train_unload function calls * Update train.py * Update base.py * Update train.py * chore: lint * clarify plugin hooks * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/utils/models.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update src/axolotl/integrations/base.py Co-authored-by: Dan Saunders <danjsaund@gmail.com> * Update models.py * Update models.py * remove extra call to post_model_load * chore: lint * add test for hooks and gc trainer * disable duplicated code check for test * fix the path and add better handling --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5d182a1056
commit
170cdb5be9
@@ -1,5 +1,6 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
del model, tokenizer, trainer
|
||||
|
||||
gc.collect()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
@@ -36,9 +36,10 @@ class BasePlugin:
|
||||
Methods:
|
||||
register(cfg): Registers the plugin with the given configuration.
|
||||
pre_model_load(cfg): Performs actions before the model is loaded.
|
||||
post_model_load(cfg, model): Performs actions after the model is loaded.
|
||||
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
|
||||
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
|
||||
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
|
||||
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
|
||||
@@ -77,6 +78,14 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the model is built/loaded, but before any adapters are applied.
|
||||
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
"""
|
||||
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
@@ -329,9 +338,22 @@ class PluginManager:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_model_load(cfg)
|
||||
|
||||
def post_model_build(self, cfg, model):
|
||||
"""
|
||||
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
|
||||
but before any adapters have been applied.
|
||||
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
model (object): The loaded model.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
"""
|
||||
Calls the post_model_load method of all registered plugins.
|
||||
Calls the post_model_load method of all registered plugins after the model has been loaded
|
||||
inclusive of any adapters
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
@@ -458,6 +480,20 @@ class PluginManager:
|
||||
callbacks.extend(plugin_callbacks)
|
||||
return callbacks
|
||||
|
||||
def post_train(self, cfg, model):
|
||||
"""
|
||||
Calls the post_train method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train(cfg, model)
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
"""
|
||||
Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
@@ -29,6 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
|
||||
from axolotl.core.trainers.mixins.sequence_parallel import (
|
||||
SequenceParallelContextManager,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
@@ -533,4 +534,7 @@ def train(
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -53,6 +53,7 @@ from transformers.integrations.deepspeed import (
|
||||
)
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -74,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||
"mllama": MllamaForConditionalGeneration,
|
||||
@@ -571,10 +573,8 @@ class ModelLoader:
|
||||
patch_gemma3conditionalgeneration_forward()
|
||||
|
||||
# load any patches from plugins
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||
|
||||
# monkey patch to allow additional Accelerator init kwargs
|
||||
if self.cfg.fp8:
|
||||
@@ -1252,6 +1252,7 @@ class ModelLoader:
|
||||
|
||||
try:
|
||||
skip_move_to_device = self.build_model(qlora_fsdp)
|
||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
@@ -1331,6 +1332,8 @@ class ModelLoader:
|
||||
before_kbit_train_or_finetune=False,
|
||||
)
|
||||
|
||||
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# load lora or adapter
|
||||
# ---------------------------------------------------------
|
||||
@@ -1392,7 +1395,7 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
|
||||
return self.model, lora_config
|
||||
|
||||
|
||||
@@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg, inference=inference)
|
||||
model, lora_config = load_lora(model, cfg, inference=inference)
|
||||
PLUGIN_MANAGER.post_lora_load(cfg, model)
|
||||
return model, lora_config
|
||||
if adapter == "llama-adapter":
|
||||
return load_llama_adapter(model, cfg)
|
||||
model, lora_config = load_llama_adapter(model, cfg)
|
||||
PLUGIN_MANAGER.post_lora_load(cfg, model)
|
||||
return model, lora_config
|
||||
|
||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user