cleanup the deepspeed proxy model at the end of training (#1675)
This commit is contained in:
@@ -197,6 +197,13 @@ def train(
|
|||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||||
|
|
||||||
|
# the trainer saved a model.safetensors file in the output directory,
|
||||||
|
# but it is a proxy model and should be deleted
|
||||||
|
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
|
||||||
|
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
||||||
|
LOG.info("This is a proxy model and should be deleted")
|
||||||
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
|
|
||||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||||
|
|||||||
Reference in New Issue
Block a user