Refactor creating FusedExperts
@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
     # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
     mixtral_moe = MixtralSparseMoeBlock(config)
     mixtral_moe_copy = deepcopy(mixtral_moe)
-
-    experts = FusedExperts(
-        experts=mixtral_moe_copy.experts,
-        input_size=mixtral_moe_copy.ffn_dim,
-        hidden_size=mixtral_moe_copy.hidden_dim,
-        num_experts=mixtral_moe_copy.num_experts,
-        top_k=mixtral_moe_copy.top_k,
-        activation=mixtral_moe_copy.experts[0].act_fn
-    )
     sparse_moe = SparseMoeBlock(
-        experts,
+        experts=mixtral_moe.experts,
         hidden_dim=config.hidden_size,
         ffn_dim=config.intermediate_size,
         num_experts=config.num_local_experts,
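
The net effect of the hunk is that the test no longer constructs a FusedExperts instance itself: SparseMoeBlock now receives the raw Mixtral expert modules and is presumably responsible for fusing them internally. Below is a minimal sketch of what that could look like, assuming SparseMoeBlock keeps the keyword arguments seen in the test and simply forwards them to FusedExperts; the real constructor in this repo may differ.

# Hypothetical sketch only, not the repo's actual implementation.
# FusedExperts is assumed importable from this repo; its keyword arguments
# below are taken from the removed test code in the hunk above.
import torch.nn as nn


class SparseMoeBlock(nn.Module):
    def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k, activation=None):
        super().__init__()
        # Fuse the per-expert HF modules internally instead of requiring the
        # caller to build FusedExperts up front, as the test did before this commit.
        self.experts = FusedExperts(
            experts=experts,
            input_size=ffn_dim,
            hidden_size=hidden_dim,
            num_experts=num_experts,
            top_k=top_k,
            activation=activation if activation is not None else experts[0].act_fn,
        )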