From 78b9efb7f4ae508ccaa2954e6b326d7bb938945c Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Thu, 3 Aug 2023 19:19:39 +0000
Subject: [PATCH] scope flash-attn+qlora fix correctly, scope to llama, add comment

---
 src/axolotl/utils/models.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 23d7716a0..0dcb4cd77 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
+        # LlamaRMSNorm layers are in fp32 after kbit call, so we need to
+        # convert them back to fp16/bf16 for flash-attn compatibility.
+        if cfg.flash_attention and cfg.is_llama_derived_model:
+            for name, module in model.named_modules():
+                if "norm" in name:
                     module.to(torch_dtype)
+                if "lm_head" in name or "embed_tokens" in name:
+                    if hasattr(module, "weight"):
+                        module.to(torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
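
For context, the hunk above works around prepare_model_for_kbit_training upcasting LlamaRMSNorm (and potentially embedding/lm_head) weights to fp32, which flash-attn kernels do not accept; the fix casts them back to the configured half precision, and this patch scopes that cast to flash-attention runs on Llama-derived models. Below is a minimal standalone sketch of the same cast-back loop, not part of the patch: the model id, load flags, and bf16 choice are illustrative assumptions, not axolotl configuration.

    # Sketch: restore half precision on norm / lm_head / embed_tokens weights
    # after k-bit preparation, mirroring the loop added in the patch above.
    # Assumes a Hugging Face Llama-style checkpoint; the model id below is a
    # placeholder and torch_dtype would come from the training config.
    import torch
    from transformers import AutoModelForCausalLM
    from peft import prepare_model_for_kbit_training

    torch_dtype = torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", load_in_4bit=True, torch_dtype=torch_dtype
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # prepare_model_for_kbit_training leaves the RMSNorm (and possibly the
    # embedding / lm_head) weights in fp32; flash-attn needs fp16 or bf16.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)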