fix torch_dtype for model load
@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (
-        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
-    )
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
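The old one-line ternary mapped every half-precision flag, including bf16, to torch.float16, so a bfloat16 config silently loaded the model in float16. The branching added here checks cfg.bf16 first. Below is a minimal sketch of the resulting selection logic; the resolve_torch_dtype helper and the SimpleNamespace config are illustrative only, not part of the repository:

import torch
from types import SimpleNamespace

def resolve_torch_dtype(cfg):
    """Pick the load dtype the same way the patched code does.

    bf16 is checked before the other flags, so a bfloat16 config no
    longer falls through to float16 as it did with the old ternary.
    """
    if cfg.bf16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16:
        return torch.float16
    return torch.float32

# With the old expression this config would have produced torch.float16.
cfg = SimpleNamespace(bf16=True, fp16=False, load_in_8bit=False)
assert resolve_torch_dtype(cfg) == torch.bfloat16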