diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2e5bd9b74..7cabda45a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1104,7 +1104,7 @@ class ModelLoader:
         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
             or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )