fix: torch_dtype mistral default to fp32 (#1050)

commit c3e8165f26
parent 7f381750d9
Author: NanoCode012 (committed by GitHub)
Date: 2024-01-09 21:48:15 +09:00


@@ -599,7 +599,10 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (
+        cfg.flash_attention
+        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
+    ):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name: