From 1991946c5a5b57bc89aaec8167066b334543aba6 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Fri, 1 Sep 2023 16:11:45 +0200
Subject: [PATCH] fix: bad dtype for full finetune (#504)

* fix: bad dtype for full finetune

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian

* Update models.py

---------

Co-authored-by: Wing Lian
---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4b9c79d84..9f0795af7 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -371,7 +371,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
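
Note on the one-character fix: with `and`, the norm-module dtype conversion only
ran when a full finetune also happened to use flash attention on a llama-derived
model; with `or`, either trigger is sufficient. The following is a minimal
standalone sketch of the two conditions, not axolotl code; the flag names mirror
the patched expression, and the scenario values are illustrative assumptions.

    # Sketch only: the flag names mirror the patch, the values are assumed.

    def should_convert_old(needs_fa2_dtype, flash_attention, is_llama_derived_model):
        # Pre-patch condition: requires BOTH the dtype requirement AND
        # flash attention on a llama-derived model.
        return needs_fa2_dtype and (flash_attention and is_llama_derived_model)

    def should_convert_new(needs_fa2_dtype, flash_attention, is_llama_derived_model):
        # Post-patch condition: EITHER trigger alone is enough.
        return needs_fa2_dtype or (flash_attention and is_llama_derived_model)

    # Hypothetical full finetune that needs the fa2 dtype, without flash
    # attention: the old condition skips the conversion, leaving norm
    # layers in fp32; the new condition converts them.
    print(should_convert_old(True, False, False))  # False -> bug
    print(should_convert_new(True, False, False))  # True  -> fixed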