diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py
index 0f68f0c43..4e4ffec31 100644
--- a/src/axolotl/monkeypatch/moe/moe.py
+++ b/src/axolotl/monkeypatch/moe/moe.py
@@ -11,7 +11,14 @@ class SparseMoeBlock(nn.Module):
         self.num_experts = num_experts
         self.top_k = top_k
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-        self.experts: FusedExperts = experts
+        self.experts = FusedExperts(
+            experts=experts,
+            input_size=ffn_dim,
+            hidden_size=hidden_dim,
+            num_experts=num_experts,
+            top_k=top_k,
+            activation=experts[0].act_fn
+        )

     def _post_training(self, model, name):
         # get original weights back: reverse the concat + stack in the fused experts
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6128269b2..a992357e9 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -720,23 +720,14 @@ def load_model(
         and not cfg.adapter
         and cfg.fuse_moe
     ):
-        from axolotl.monkeypatch.moe.mlp import FusedExperts
        from axolotl.monkeypatch.utils import set_module_name
        from axolotl.monkeypatch.moe.moe import SparseMoeBlock
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
-                experts = FusedExperts(
-                    experts=module.experts,
-                    input_size=module.ffn_dim,
-                    hidden_size=module.hidden_dim,
-                    num_experts=module.num_experts,
-                    top_k=module.top_k,
-                    activation=module.experts[0].act_fn
-                )
                smoe = SparseMoeBlock(
-                    experts=experts,
+                    experts=module.experts,
                    hidden_dim=module.hidden_dim,
                    ffn_dim=module.ffn_dim,
                    num_experts=module.num_experts,
diff --git a/tests/monkeypatch/test_moe.py b/tests/monkeypatch/test_moe.py
index 2e4f7271b..503de3153 100644
--- a/tests/monkeypatch/test_moe.py
+++ b/tests/monkeypatch/test_moe.py
@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
     # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
     mixtral_moe = MixtralSparseMoeBlock(config)
-    mixtral_moe_copy = deepcopy(mixtral_moe)
-    experts = FusedExperts(
-        experts=mixtral_moe_copy.experts,
-        input_size=mixtral_moe_copy.ffn_dim,
-        hidden_size=mixtral_moe_copy.hidden_dim,
-        num_experts=mixtral_moe_copy.num_experts,
-        top_k=mixtral_moe_copy.top_k,
-        activation=mixtral_moe_copy.experts[0].act_fn
-    )
     sparse_moe = SparseMoeBlock(
-        experts,
+        experts=mixtral_moe.experts,
         hidden_dim=config.hidden_size,
         ffn_dim=config.intermediate_size,
         num_experts=config.num_local_experts,
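Not part of the patch: a minimal usage sketch of the refactored call site, showing that FusedExperts is now built inside SparseMoeBlock.__init__ so callers pass the unfused Hugging Face expert modules straight through. The MixtralConfig values below are illustrative assumptions, not taken from the patch.

# Sketch only; config sizes are made up for the example.
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from axolotl.monkeypatch.moe.moe import SparseMoeBlock

config = MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_local_experts=4,
    num_experts_per_tok=2,
)
mixtral_moe = MixtralSparseMoeBlock(config)

# The fused wrapper is constructed internally from the raw experts,
# mirroring the simplified call sites in load_model and the test above.
sparse_moe = SparseMoeBlock(
    experts=mixtral_moe.experts,
    hidden_dim=config.hidden_size,
    ffn_dim=config.intermediate_size,
    num_experts=config.num_local_experts,
    top_k=config.num_experts_per_tok,
)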