Attempt to work around the bf16 dtype error
This commit is contained in:
@@ -1104,7 +1104,7 @@ class ModelLoader:
|
||||
should_convert = (
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
|
||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user