From a276c9c88dfb27ea42165d6445de1a3b913c93cb Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 13 Aug 2023 01:22:52 +0900
Subject: [PATCH] Fix(save): Save as safetensors (#363)

---
 scripts/finetune.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 635c06930..e6fea4456 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -257,6 +257,8 @@ def train(
     LOG.info("loading model and peft_config...")
     model, peft_config = load_model(cfg, tokenizer)
 
+    safe_serialization = cfg.save_safetensors is True
+
     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
@@ -264,7 +266,10 @@ def train(
 
         if cfg.local_rank == 0:
             LOG.info("saving merged model")
-            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+            model.save_pretrained(
+                str(Path(cfg.output_dir) / "merged"),
+                safe_serialization=safe_serialization,
+            )
             tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
@@ -280,7 +285,7 @@ def train(
         return
 
     if "shard" in kwargs:
-        model.save_pretrained(cfg.output_dir)
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return
 
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
@@ -302,7 +307,7 @@ def train(
     def terminate_handler(_, __, model):
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
-        model.save_pretrained(cfg.output_dir)
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         sys.exit(0)
 
     signal.signal(
@@ -342,11 +347,11 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-        model.save_pretrained(cfg.output_dir)
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
-        model.save_pretrained(cfg.output_dir)
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
 # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
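
Why this works: PreTrainedModel.save_pretrained in Hugging Face transformers accepts a
safe_serialization keyword; when True, it writes the weights as model.safetensors instead
of the pickle-based pytorch_model.bin. The patch derives the flag once from the axolotl
config (safe_serialization = cfg.save_safetensors is True, so an unset/None value falls
back to False) and threads it through every save site. A minimal sketch of the behavior,
assuming a transformers install contemporary with this patch; the model name and output
paths below are illustrative only, not taken from the patch:

    # sketch.py -- illustrates the safe_serialization flag; not part of the patch
    from transformers import AutoModelForCausalLM

    # tiny public model used purely for demonstration
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    # Pickle-based checkpoint: writes pytorch_model.bin
    model.save_pretrained("out/bin", safe_serialization=False)

    # Safetensors checkpoint: writes model.safetensors
    model.save_pretrained("out/safetensors", safe_serialization=True)

With this patch applied, setting save_safetensors: true in the axolotl config makes all
of the save paths above (merged LoRA model, shard, the SIGINT handler, FSDP, and the
final rank-0 save) emit safetensors checkpoints.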