diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 40b78a969..98d1e7c48 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -599,7 +599,10 @@ def load_model( # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # convert them back to fp16/bf16 for flash-attn compatibility. - if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model): + if needs_fa2_dtype or ( + cfg.flash_attention + and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model) + ): LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) for name, module in model.named_modules(): if "norm" in name: