move the plugin post trainer create to the setup trainer (#2907)
* move the plugin post trainer create to the setup trainer * move post-train plugins to execute-training fn
This commit is contained in:
@@ -224,6 +224,9 @@ def execute_training(
|
|||||||
# torch.set_default_dtype(torch.bfloat16)
|
# torch.set_default_dtype(torch.bfloat16)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_manager.post_train(cfg, trainer.model)
|
||||||
|
|
||||||
|
|
||||||
def save_trained_model(
|
def save_trained_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
peft_config=peft_config,
|
peft_config=peft_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_manager.post_trainer_create(cfg, trainer)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
trainer,
|
trainer,
|
||||||
model,
|
model,
|
||||||
@@ -541,9 +547,6 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
plugin_manager.post_trainer_create(cfg, trainer)
|
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
@@ -566,6 +569,4 @@ def train(
|
|||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
|
||||||
plugin_manager.post_train(cfg, model)
|
|
||||||
|
|
||||||
return model, tokenizer, trainer
|
return model, tokenizer, trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user