diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 23d7716a0..0dcb4cd77 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -333,13 +333,15 @@ def load_model( model, use_gradient_checkpointing=cfg.gradient_checkpointing ) - if cfg.flash_attention: - for name, module in model.named_modules(): - if "norm" in name: - module.to(torch_dtype) - if "lm_head" in name or "embed_tokens" in name: - if hasattr(module, "weight"): + # LlamaRMSNorm layers are in fp32 after kit call, so we need to + # convert them back to fp16/bf16 for flash-attn compatibility. + if cfg.flash_attention and cfg.is_llama_derived_model: + for name, module in model.named_modules(): + if "norm" in name: module.to(torch_dtype) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module.to(torch_dtype) model, lora_config = load_adapter(model, cfg, adapter)