move the plugin post trainer create to the setup trainer (#2907)
* move the plugin post trainer create to the setup trainer * move post-train plugins to execute-training fn
This commit is contained in:
@@ -224,6 +224,9 @@ def execute_training(
|
||||
# torch.set_default_dtype(torch.bfloat16)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, trainer.model)
|
||||
|
||||
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
@@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
model,
|
||||
@@ -541,9 +547,6 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -566,6 +569,4 @@ def train(
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
Reference in New Issue
Block a user