From 53e739f11e29d554c8a37b0877a34f89f5447394 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 22 Aug 2023 23:17:44 -0400
Subject: [PATCH] deduplicate code

---
 scripts/finetune.py | 39 ++++++++++++++++-----------------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index df70e1c83..fd5fafbe1 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -155,6 +155,20 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
     return not any(el in list2 for el in list1)
 
 
+def merge_lora(model, tokenizer, cfg):
+    LOG.info("running merge of LoRA with base model")
+    model = model.merge_and_unload()
+    model.to(dtype=torch.float16)
+
+    if cfg.local_rank == 0:
+        LOG.info("saving merged model")
+        model.save_pretrained(
+            str(Path(cfg.output_dir) / "merged"),
+            safe_serialization=cfg.save_safetensors is True,
+        )
+        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -214,17 +228,7 @@ def train(
     safe_serialization = cfg.save_safetensors is True
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
-        LOG.info("running merge of LoRA with base model")
-        model = model.merge_and_unload()
-        model.to(dtype=torch.float16)
-
-        if cfg.local_rank == 0:
-            LOG.info("saving merged model")
-            model.save_pretrained(
-                str(Path(cfg.output_dir) / "merged"),
-                safe_serialization=safe_serialization,
-            )
-            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        merge_lora(model, tokenizer, cfg)
         return
 
     if cfg.inference:
@@ -311,18 +315,7 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if cfg.adapter is not None:
-        # pylint: disable=duplicate-code
-        LOG.info("running merge of LoRA with base model")
-        model = model.merge_and_unload()
-        model.to(dtype=torch.float16)
-
-        if cfg.local_rank == 0:
-            LOG.info("saving merged model")
-            model.save_pretrained(
-                str(Path(cfg.output_dir) / "merged"),
-                safe_serialization=safe_serialization,
-            )
-            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        merge_lora(model, tokenizer, cfg)
 
 
 if __name__ == "__main__":
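
Reviewer note: a minimal usage sketch of the extracted merge_lora helper.
Everything below is illustrative only; the model/tokenizer loading and the
SimpleNamespace cfg stand in for axolotl's own loading path and config
object, and the model/adapter paths are hypothetical placeholders.

    from types import SimpleNamespace

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # load a base model and wrap it with a trained LoRA adapter (paths are
    # placeholders; axolotl resolves these from its YAML config instead)
    base = AutoModelForCausalLM.from_pretrained(
        "base-model-id", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, "path/to/lora-adapter")
    tokenizer = AutoTokenizer.from_pretrained("base-model-id")

    # minimal stand-in exposing only the cfg fields that merge_lora reads
    cfg = SimpleNamespace(local_rank=0, output_dir="out", save_safetensors=True)

    merge_lora(model, tokenizer, cfg)  # writes merged fp16 weights to out/merged

The helper calls peft's merge_and_unload() to fold the adapter weights into
the base model, casts to float16, and saves to <output_dir>/merged on the
rank-0 process only, matching the two code paths the patch deduplicates.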