Refactor creating FusedExperts
This commit is contained in:
@@ -11,7 +11,14 @@ class SparseMoeBlock(nn.Module):
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
self.experts: FusedExperts = experts
|
||||
self.experts = FusedExperts(
|
||||
experts=experts,
|
||||
input_size=ffn_dim,
|
||||
hidden_size=hidden_dim,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
activation=experts[0].act_fn
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
# get original weights back: reverse the concat + stack in the fused experts
|
||||
|
||||
@@ -720,23 +720,14 @@ def load_model(
|
||||
and not cfg.adapter
|
||||
and cfg.fuse_moe
|
||||
):
|
||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, MixtralSparseMoeBlock):
|
||||
experts = FusedExperts(
|
||||
experts=module.experts,
|
||||
input_size=module.ffn_dim,
|
||||
hidden_size=module.hidden_dim,
|
||||
num_experts=module.num_experts,
|
||||
top_k=module.top_k,
|
||||
activation=module.experts[0].act_fn
|
||||
)
|
||||
smoe = SparseMoeBlock(
|
||||
experts=experts,
|
||||
experts=module.experts,
|
||||
hidden_dim=module.hidden_dim,
|
||||
ffn_dim=module.ffn_dim,
|
||||
num_experts=module.num_experts,
|
||||
|
||||
@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
|
||||
|
||||
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||
mixtral_moe_copy = deepcopy(mixtral_moe)
|
||||
|
||||
experts = FusedExperts(
|
||||
experts=mixtral_moe_copy.experts,
|
||||
input_size=mixtral_moe_copy.ffn_dim,
|
||||
hidden_size=mixtral_moe_copy.hidden_dim,
|
||||
num_experts=mixtral_moe_copy.num_experts,
|
||||
top_k=mixtral_moe_copy.top_k,
|
||||
activation=mixtral_moe_copy.experts[0].act_fn
|
||||
)
|
||||
sparse_moe = SparseMoeBlock(
|
||||
experts,
|
||||
experts=mixtral_moe.experts,
|
||||
hidden_dim=config.hidden_size,
|
||||
ffn_dim=config.intermediate_size,
|
||||
num_experts=config.num_local_experts,
|
||||
|
||||
Reference in New Issue
Block a user