From 9aaa4b8ced3a84c53b9c5c8a85fbe0f97d9174ea Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 23 Aug 2023 03:55:44 -0400
Subject: [PATCH] set model merge dtype based on cfg

---
 scripts/finetune.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 08f4813e6..dda42e79b 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -158,7 +158,8 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
 def merge_lora(model, tokenizer, cfg):
     LOG.info("running merge of LoRA with base model")
     model = model.merge_and_unload()
-    model.to(dtype=torch.float16)
+    model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16
+    model.to(dtype=model_dtype)
     if cfg.hub_model_id:
         model.push_to_hub("hub_model_id")
 