attempt at getting around bf16 error

Sunny Liu
2025-02-04 21:57:21 -05:00
parent 3f6be519d5
commit d0e739da24


@@ -1104,7 +1104,7 @@ class ModelLoader:
         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
             or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )