From c3e8165f265b7a717aa56ed0795eed7eb35e2cff Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 9 Jan 2024 21:48:15 +0900
Subject: [PATCH] fix: torch_dtype mistral default to fp32 (#1050)

---
 src/axolotl/utils/models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 40b78a969..98d1e7c48 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -599,7 +599,10 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (
+        cfg.flash_attention
+        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
+    ):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name: