Only fuse if flash_attn_fuse_mlp is True

This commit is contained in:
Casper
2023-12-10 19:17:12 +01:00
parent 279a1401b5
commit a58a9e5f6c

View File

@@ -385,6 +385,7 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
if cfg.flash_attn_fuse_mlp:
LOG.info("Mixtral MoE: Replacing experts with SwiGLU") LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model) replace_mixtral_mlp_with_swiglu(model)