Refactor creating FusedExperts

Author: Casper Hansen
Date: 2024-03-15 11:59:56 +00:00
Parent: 3f7ed6a784
Commit: 1bc008e901
3 changed files with 10 additions and 21 deletions

@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
     # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
     mixtral_moe = MixtralSparseMoeBlock(config)
     mixtral_moe_copy = deepcopy(mixtral_moe)
-    experts = FusedExperts(
-        experts=mixtral_moe_copy.experts,
-        input_size=mixtral_moe_copy.ffn_dim,
-        hidden_size=mixtral_moe_copy.hidden_dim,
-        num_experts=mixtral_moe_copy.num_experts,
-        top_k=mixtral_moe_copy.top_k,
-        activation=mixtral_moe_copy.experts[0].act_fn
-    )
     sparse_moe = SparseMoeBlock(
-        experts,
+        experts=mixtral_moe.experts,
         hidden_dim=config.hidden_size,
         ffn_dim=config.intermediate_size,
         num_experts=config.num_local_experts,
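
For context beyond this hunk: the test no longer builds FusedExperts itself, and the new call hands the raw experts plus the config dimensions straight to SparseMoeBlock, which suggests the fusion now happens inside that constructor (in one of the other changed files, not shown here). A minimal sketch of what that could look like, reusing the argument mapping from the deleted test code; the class body, the attribute name `fused_experts`, and the omitted import of FusedExperts are assumptions, not the commit's actual code:

import torch.nn as nn

# FusedExperts is the project's own class; its import path is not visible in
# this hunk, so it is omitted here.

class SparseMoeBlock(nn.Module):
    # Sketch of the refactored constructor (the real one lives in another of
    # the three changed files). It mirrors the keyword arguments of the new
    # test call and absorbs the FusedExperts creation the test used to do.
    def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k):
        super().__init__()
        # Same argument mapping as the deleted test code above; the attribute
        # name `fused_experts` is hypothetical.
        self.fused_experts = FusedExperts(
            experts=experts,
            input_size=ffn_dim,
            hidden_size=hidden_dim,
            num_experts=num_experts,
            top_k=top_k,
            activation=experts[0].act_fn,
        )

Under this shape, the test and production callers construct the block the same way and no longer need to know about FusedExperts at all, which appears to be the point of the refactor.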