From aef00b6c13054515d37dbc9dc6d41a70245ba1ae Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 14 May 2023 08:44:22 -0400
Subject: [PATCH] fix torch_dtype for model load

---
 src/axolotl/utils/models.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bec6d8194..7eef944f3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (
-        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
-    )
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (