From f20c8deff178bf0caba4caf0144dd1f2aa6e6e68 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 22 Aug 2023 23:15:09 -0400
Subject: [PATCH] merge lora on train completion

---
 scripts/finetune.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 78eca05b9..df70e1c83 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -310,6 +310,20 @@ def train(
         model = BetterTransformer.reverse(model)
 
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
+    if cfg.adapter is not None:
+        # pylint: disable=duplicate-code
+        LOG.info("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+        model.to(dtype=torch.float16)
+
+        if cfg.local_rank == 0:
+            LOG.info("saving merged model")
+            model.save_pretrained(
+                str(Path(cfg.output_dir) / "merged"),
+                safe_serialization=safe_serialization,
+            )
+            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
 if __name__ == "__main__":
     fire.Fire(train)