From 3a011ea1ef4ddee446e22849651783dd758dfda6 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 27 Aug 2023 20:09:26 +0000
Subject: [PATCH] fix condition and add logging

---
 src/axolotl/utils/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index dd75106ec..c2fbc19e3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,7 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    needs_fa2_dtype = not cfg.adapter
+    needs_fa2_dtype = cfg.adapter is not None
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -369,6 +369,7 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
     if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)