Refactor FusedExperts creation

Casper Hansen
2024-03-15 11:59:56 +00:00
parent 3f7ed6a784
commit 1bc008e901
3 changed files with 10 additions and 21 deletions


@@ -11,7 +11,14 @@ class SparseMoeBlock(nn.Module):
         self.num_experts = num_experts
         self.top_k = top_k
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-        self.experts: FusedExperts = experts
+        self.experts = FusedExperts(
+            experts=experts,
+            input_size=ffn_dim,
+            hidden_size=hidden_dim,
+            num_experts=num_experts,
+            top_k=top_k,
+            activation=experts[0].act_fn
+        )
 
     def _post_training(self, model, name):
         # get original weights back: reverse the concat + stack in the fused experts
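For context, the _post_training comment above refers to what FusedExperts does at construction time: it concatenates and stacks each expert's projection weights into single tensors so routed tokens go through one batched matmul instead of a per-expert Python loop. A rough sketch of that idea, assuming Mixtral-style experts with w1/w2/w3 linear layers; this is an illustration, not the actual FusedExperts implementation:

    import torch
    import torch.nn as nn

    def stack_expert_weights(experts: nn.ModuleList) -> dict:
        # Stack each expert's projection weight into one tensor of shape
        # (num_experts, out_features, in_features); a batched matmul over
        # dim 0 then applies every expert at once. _post_training has to
        # reverse exactly this concat + stack to recover the originals.
        return {
            name: torch.stack([getattr(e, name).weight for e in experts], dim=0)
            for name in ("w1", "w2", "w3")  # Mixtral expert projections
        }

Moving this construction into SparseMoeBlock.__init__ is what lets the two call sites below shrink: the monkeypatch and the test now pass the unfused experts straight through.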


@@ -720,23 +720,14 @@ def load_model(
         and not cfg.adapter
         and cfg.fuse_moe
     ):
-        from axolotl.monkeypatch.moe.mlp import FusedExperts
         from axolotl.monkeypatch.utils import set_module_name
         from axolotl.monkeypatch.moe.moe import SparseMoeBlock
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         for name, module in model.named_modules():
             if isinstance(module, MixtralSparseMoeBlock):
-                experts = FusedExperts(
-                    experts=module.experts,
-                    input_size=module.ffn_dim,
-                    hidden_size=module.hidden_dim,
-                    num_experts=module.num_experts,
-                    top_k=module.top_k,
-                    activation=module.experts[0].act_fn
-                )
                 smoe = SparseMoeBlock(
-                    experts=experts,
+                    experts=module.experts,
                     hidden_dim=module.hidden_dim,
                     ffn_dim=module.ffn_dim,
                     num_experts=module.num_experts,
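The hunk is truncated after num_experts; presumably the remaining constructor kwargs follow and the loop ends by swapping the patched block into the model with the imported helper. A hedged sketch of that tail, assuming set_module_name(model, name, value) replaces the submodule registered under name:

    # Hedged sketch of the loop's truncated tail (not shown in the hunk):
    # after the remaining constructor kwargs (top_k, assumed from
    # SparseMoeBlock.__init__ above), the stock Mixtral block is replaced:
    set_module_name(model, name, smoe)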


@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
 
     # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
     mixtral_moe = MixtralSparseMoeBlock(config)
-    mixtral_moe_copy = deepcopy(mixtral_moe)
-    experts = FusedExperts(
-        experts=mixtral_moe_copy.experts,
-        input_size=mixtral_moe_copy.ffn_dim,
-        hidden_size=mixtral_moe_copy.hidden_dim,
-        num_experts=mixtral_moe_copy.num_experts,
-        top_k=mixtral_moe_copy.top_k,
-        activation=mixtral_moe_copy.experts[0].act_fn
-    )
 
     sparse_moe = SparseMoeBlock(
-        experts,
+        experts=mixtral_moe.experts,
        hidden_dim=config.hidden_size,
        ffn_dim=config.intermediate_size,
        num_experts=config.num_local_experts,
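This hunk also cuts off before the assertions. A parity test like this typically finishes by running one batch through both blocks and comparing outputs; a hedged sketch of that tail (MixtralSparseMoeBlock.forward does return a (hidden_states, router_logits) tuple, while the SparseMoeBlock return signature is an assumption):

    x = torch.randn(2, 16, config.hidden_size)
    ref_out, _ = mixtral_moe(x)     # reference block: (hidden_states, router_logits)
    fused_out, _ = sparse_moe(x)    # assumed to mirror Mixtral's return signature
    assert torch.allclose(ref_out, fused_out, atol=1e-5)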