diff --git a/scripts/finetune.py b/scripts/finetune.py
index 08f4813e6..dda42e79b 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -158,7 +158,8 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
 def merge_lora(model, tokenizer, cfg):
     LOG.info("running merge of LoRA with base model")
     model = model.merge_and_unload()
-    model.to(dtype=torch.float16)
+    model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16
+    model.to(dtype=model_dtype)
 
     if cfg.hub_model_id:
         model.push_to_hub("hub_model_id")
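
For illustration, a minimal standalone sketch of the dtype selection this patch introduces. Here cfg is a hypothetical stand-in for the training config object, and bf16/bfloat16 are assumed to be optional boolean flags read the same way the patch reads them; this is not the project's own code.

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the training config; bf16/bfloat16 are optional flags.
cfg = SimpleNamespace(bf16=True, bfloat16=None)

# Mirror the patched logic: prefer bfloat16 when either flag is set, else float16.
model_dtype = torch.bfloat16 if cfg.bf16 or cfg.bfloat16 else torch.float16

print(model_dtype)  # torch.bfloat16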