diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bec6d8194..7eef944f3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (
-        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
-    )
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
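
The diff above fixes a dtype-selection bug: the old one-line conditional mapped `cfg.bf16` to `torch.float16`, silently dropping bfloat16's wider exponent range. Below is a minimal standalone sketch of the corrected precedence. The `Cfg` dataclass and `resolve_torch_dtype` helper are hypothetical illustrations, not part of axolotl's API; only the flag names (`bf16`, `fp16`, `load_in_8bit`) and the branch order come from the diff itself.

```python
from dataclasses import dataclass

import torch


@dataclass
class Cfg:
    # Hypothetical stand-in for axolotl's config object.
    bf16: bool = False
    fp16: bool = False
    load_in_8bit: bool = False


def resolve_torch_dtype(cfg: Cfg) -> torch.dtype:
    # bf16 must be checked first: in the old expression it fell through
    # to torch.float16 along with fp16 and 8-bit loading.
    if cfg.bf16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16:
        return torch.float16
    return torch.float32


# torch dtypes are singletons, so identity checks are safe here.
assert resolve_torch_dtype(Cfg(bf16=True)) is torch.bfloat16   # was float16 before the fix
assert resolve_torch_dtype(Cfg(fp16=True)) is torch.float16
assert resolve_torch_dtype(Cfg(load_in_8bit=True)) is torch.float16
assert resolve_torch_dtype(Cfg()) is torch.float32
```

Keeping the branches explicit (rather than packing all three dtypes into one ternary) also makes the precedence obvious to future readers: bf16 wins over fp16/8-bit, and float32 is the fallback.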