diff --git a/scripts/finetune.py b/scripts/finetune.py
index 78eca05b9..df70e1c83 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -310,6 +310,20 @@ def train(
         model = BetterTransformer.reverse(model)
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
+    if cfg.adapter is not None:
+        # pylint: disable=duplicate-code
+        LOG.info("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+        model.to(dtype=torch.float16)
+
+        if cfg.local_rank == 0:
+            LOG.info("saving merged model")
+            model.save_pretrained(
+                str(Path(cfg.output_dir) / "merged"),
+                safe_serialization=safe_serialization,
+            )
+            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
 
 if __name__ == "__main__":
     fire.Fire(train)
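
For reference, the block this patch adds relies on `model` still being a PEFT-wrapped model when `cfg.adapter` is set, so `merge_and_unload()` folds the LoRA deltas into the base weights before saving. Below is a minimal standalone sketch of the same merge-and-save flow, assuming the adapter is a PEFT LoRA checkpoint; the base model name and directory paths are hypothetical placeholders, not values taken from the patch.

```python
import torch
from pathlib import Path

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "huggyllama/llama-7b"  # hypothetical base model
adapter_dir = "./lora-out"               # hypothetical LoRA checkpoint dir
output_dir = Path("./lora-out")          # hypothetical output directory

# Load the frozen base model and attach the trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base, adapter_dir)

# Fold the adapter weights into the base weights and drop the PEFT
# wrappers, mirroring the merge_and_unload() call in the patch.
model = model.merge_and_unload()
model.to(dtype=torch.float16)

# Save the merged weights and tokenizer side by side, as the patch does
# under <output_dir>/merged.
merged_dir = output_dir / "merged"
model.save_pretrained(str(merged_dir), safe_serialization=True)
AutoTokenizer.from_pretrained(base_model_name).save_pretrained(str(merged_dir))
```

The point of merging up front is that the result in `merged/` loads with plain `AutoModelForCausalLM.from_pretrained` and no PEFT dependency, at the cost of storing a full copy of the weights.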