Fix torch_dtype selection for model load: use torch.bfloat16 when cfg.bf16 is set (previously fell back to torch.float16)
This commit is contained in:
@@ -62,9 +62,12 @@ def load_model(
|
|||||||
logging.info("patching with xformers attention")
|
logging.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
|
|
||||||
torch_dtype = (
|
if cfg.bf16:
|
||||||
torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
|
torch_dtype = torch.bfloat16
|
||||||
)
|
elif cfg.load_in_8bit or cfg.fp16:
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
torch_dtype = torch.float32
|
||||||
try:
|
try:
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||||
|
|||||||
Reference in New Issue
Block a user