From d0e739da24b66a2b79bd5a2b4bd478976677c90a Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Tue, 4 Feb 2025 21:57:21 -0500
Subject: [PATCH] attempt at getting around bf16 error

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2e5bd9b74..7cabda45a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1104,7 +1104,7 @@ class ModelLoader:
         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
             or self.cfg.cut_cross_entropy
             # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )
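
Note: for context on what the `should_convert` flag gates, below is a minimal,
hedged sketch of the kind of dtype fix-up implied by the comments in the hunk:
casting norm and embedding layers back to fp16/bf16 after k-bit training prep
or a full finetune leaves them in fp32. The helper name
`convert_norms_and_embeddings` and its call site are hypothetical
illustrations, not axolotl's actual API.

    # Hedged sketch only: a hypothetical illustration of the conversion that
    # `should_convert` gates; this is not axolotl's actual implementation.
    import torch
    import torch.nn as nn

    def convert_norms_and_embeddings(model: nn.Module, dtype: torch.dtype) -> None:
        """Cast norm and embedding layers to `dtype` (e.g. torch.bfloat16).

        After kbit_training or a full finetune, RMSNorm layers can be left in
        fp32; flash/flex attention kernels expect fp16/bf16 activations, and
        cut cross entropy needs fp16/bf16 embeddings for its backward pass.
        """
        for name, module in model.named_modules():
            # Match norm layers by name and embedding layers by type.
            if "norm" in name.lower() or isinstance(module, nn.Embedding):
                module.to(dtype)

    # Usage, assuming bf16 training:
    # convert_norms_and_embeddings(model, torch.bfloat16)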