set model merge dtype based on cfg
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-08-23 03:55:44 -04:00
parent 8be7da8999
commit 9aaa4b8ced

View File

@@ -158,7 +158,8 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
def merge_lora(model, tokenizer, cfg):
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16
model.to(dtype=model_dtype)
if cfg.hub_model_id:
model.push_to_hub("hub_model_id")