set model merge dtype based on cfg
This commit is contained in:
@@ -158,7 +158,8 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
||||
def merge_lora(model, tokenizer, cfg):
    """Merge LoRA adapter weights into the base model and optionally push to the Hub.

    Args:
        model: A PEFT-wrapped model exposing ``merge_and_unload()``.
        tokenizer: Tokenizer associated with the model (currently unused here;
            kept for interface compatibility with callers).
        cfg: Run configuration; ``cfg.bf16`` / ``cfg.bfloat16`` select the merge
            dtype, ``cfg.hub_model_id`` (if set) names the Hub repo to push to.

    Returns:
        None. ``model`` is mutated in place (merged, cast, optionally pushed).
    """
    LOG.info("running merge of LoRA with base model")
    # Fold the LoRA adapter weights into the base model weights.
    model = model.merge_and_unload()
    # Cast to the dtype the run was configured for, rather than always fp16.
    model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16
    model.to(dtype=model_dtype)
    if cfg.hub_model_id:
        # BUG FIX: previously pushed to the literal repo name "hub_model_id";
        # use the configured repo id instead.
        model.push_to_hub(cfg.hub_model_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user