diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index ed917d963..4575f5966 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -368,7 +368,7 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if (fix_dtype or not cfg.adapter) and (
+    if fix_dtype and (
         cfg.flash_attention and cfg.is_llama_derived_model
     ):
         for name, module in model.named_modules():
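For context, here is a minimal sketch of what the guarded loop presumably does after this condition: iterate the model's modules and cast the norm layers back to the half-precision dtype so flash-attn sees a uniform dtype. The hunk only shows the guard and the start of the `named_modules()` iteration; the `cast_norms_to_half` helper name, the `"norm"` substring match, and the `dtype` default below are illustrative assumptions, not code from the diff.

```python
import torch
import torch.nn as nn


def cast_norms_to_half(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> None:
    """Sketch: cast RMSNorm-style layers back to fp16/bf16.

    kbit training or a full finetune can leave LlamaRMSNorm layers in fp32,
    and flash-attn expects a uniform half-precision dtype across the model.
    """
    for name, module in model.named_modules():
        # Matching by name is an assumed heuristic; it avoids importing
        # transformers just to isinstance-check against LlamaRMSNorm.
        if "norm" in name:
            # nn.Module.to() converts parameters and buffers in place.
            module.to(dtype)
```

Under this reading, narrowing the guard from `(fix_dtype or not cfg.adapter)` to `fix_dtype` means the cast now runs only when the caller explicitly requests it, rather than also firing for every non-adapter (full-finetune) load.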