Only fuse if flash_attn_fuse_mlp is True

This commit is contained in:
Casper
2023-12-10 19:17:12 +01:00
parent 279a1401b5
commit a58a9e5f6c

View File

@@ -385,8 +385,9 @@ def load_model(
**model_kwargs,
)
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model)
if cfg.flash_attn_fuse_mlp:
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work