diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 35c58501c..967179903 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -224,6 +224,9 @@ def execute_training( # torch.set_default_dtype(torch.bfloat16) trainer.train(resume_from_checkpoint=resume_from_checkpoint) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_train(cfg, trainer.model) + def save_trained_model( cfg: DictDefault, @@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> peft_config=peft_config, ) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_trainer_create(cfg, trainer) + return ( trainer, model, @@ -541,9 +547,6 @@ def train( processor, ) = setup_model_and_trainer(cfg, dataset_meta) - plugin_manager = PluginManager.get_instance() - plugin_manager.post_trainer_create(cfg, trainer) - # Handle untrained tokens if configured safe_serialization = cfg.save_safetensors is True train_dataset = dataset_meta.train_dataset @@ -566,6 +569,4 @@ def train( if not cfg.use_ray: cleanup_distributed() - plugin_manager.post_train(cfg, model) - return model, tokenizer, trainer