fix FSDP save of final model (#329)

This commit is contained in:
Wing Lian
2023-07-30 21:46:44 -04:00
committed by GitHub
parent 41a4d15d43
commit 894cba09f3

View File

@@ -344,7 +344,9 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
if cfg.fsdp:
model.save_pretrained(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)