merge lora on train completion

This commit is contained in:
Wing Lian
2023-08-22 23:15:09 -04:00
parent d5dcf9c350
commit f20c8deff1


@@ -310,6 +310,20 @@ def train(
        model = BetterTransformer.reverse(model)

    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    if cfg.adapter is not None:
        # pylint: disable=duplicate-code
        LOG.info("running merge of LoRA with base model")
        model = model.merge_and_unload()
        model.to(dtype=torch.float16)

        if cfg.local_rank == 0:
            LOG.info("saving merged model")
            model.save_pretrained(
                str(Path(cfg.output_dir) / "merged"),
                safe_serialization=safe_serialization,
            )
            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))


if __name__ == "__main__":
    fire.Fire(train)
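
For context, the "merged" directory written by this change is an ordinary Transformers checkpoint, since merge_and_unload() folds the LoRA weights back into the base model. A minimal sketch of loading it afterwards, assuming a run whose output_dir was ./out and a causal LM base model (both illustrative, not from the patch):

    from pathlib import Path

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Path produced by the patch above: <output_dir>/merged (output_dir is an assumption)
    merged_dir = Path("./out") / "merged"

    # The adapter has already been merged, so no PEFT code is needed here;
    # the checkpoint loads like any standalone model.
    model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(merged_dir)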