add logging and ensure the merged model is cast to float16 after unload

This commit is contained in:
Wing Lian
2023-05-26 00:09:55 -04:00
parent a4f12415a0
commit a5bf838685
2 changed files with 7 additions and 0 deletions

View File

@@ -176,6 +176,7 @@ def train(
if "merge_lora" in kwargs and cfg.adapter is not None:
logging.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
logging.info("saving merged model")