fix torch_dtype for model load

Wing Lian
2023-05-14 08:44:22 -04:00
parent 0d28df0fd2
commit aef00b6c13

@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
-    torch_dtype = (
-        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
-    )
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
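
The change replaces a single ternary, which collapsed cfg.bf16 into torch.float16, with an explicit branch so bf16 runs actually load the model in bfloat16. Below is a minimal standalone sketch of the corrected precedence; the Config dataclass and resolve_torch_dtype helper are illustrative only and not part of the axolotl codebase.

from dataclasses import dataclass

import torch


@dataclass
class Config:
    bf16: bool = False
    fp16: bool = False
    load_in_8bit: bool = False


def resolve_torch_dtype(cfg: Config) -> torch.dtype:
    """Pick the load dtype, with bf16 taking precedence as in the new branch above."""
    if cfg.bf16:
        # The old ternary fell through to float16 here, which was the bug.
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16:
        return torch.float16
    return torch.float32


assert resolve_torch_dtype(Config(bf16=True)) is torch.bfloat16
assert resolve_torch_dtype(Config(load_in_8bit=True)) is torch.float16
assert resolve_torch_dtype(Config()) is torch.float32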