Refactor creating FusedExperts
This commit is contained in:
@@ -11,7 +11,14 @@ class SparseMoeBlock(nn.Module):
|
|||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||||
self.experts: FusedExperts = experts
|
self.experts = FusedExperts(
|
||||||
|
experts=experts,
|
||||||
|
input_size=ffn_dim,
|
||||||
|
hidden_size=hidden_dim,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
activation=experts[0].act_fn
|
||||||
|
)
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
def _post_training(self, model, name):
|
||||||
# get original weights back: reverse the concat + stack in the fused experts
|
# get original weights back: reverse the concat + stack in the fused experts
|
||||||
|
|||||||
@@ -720,23 +720,14 @@ def load_model(
|
|||||||
and not cfg.adapter
|
and not cfg.adapter
|
||||||
and cfg.fuse_moe
|
and cfg.fuse_moe
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
|
||||||
from axolotl.monkeypatch.utils import set_module_name
|
from axolotl.monkeypatch.utils import set_module_name
|
||||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, MixtralSparseMoeBlock):
|
if isinstance(module, MixtralSparseMoeBlock):
|
||||||
experts = FusedExperts(
|
|
||||||
experts=module.experts,
|
|
||||||
input_size=module.ffn_dim,
|
|
||||||
hidden_size=module.hidden_dim,
|
|
||||||
num_experts=module.num_experts,
|
|
||||||
top_k=module.top_k,
|
|
||||||
activation=module.experts[0].act_fn
|
|
||||||
)
|
|
||||||
smoe = SparseMoeBlock(
|
smoe = SparseMoeBlock(
|
||||||
experts=experts,
|
experts=module.experts,
|
||||||
hidden_dim=module.hidden_dim,
|
hidden_dim=module.hidden_dim,
|
||||||
ffn_dim=module.ffn_dim,
|
ffn_dim=module.ffn_dim,
|
||||||
num_experts=module.num_experts,
|
num_experts=module.num_experts,
|
||||||
|
|||||||
@@ -23,18 +23,9 @@ def test_fused_mixtral_moe():
|
|||||||
|
|
||||||
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||||
mixtral_moe = MixtralSparseMoeBlock(config)
|
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||||
mixtral_moe_copy = deepcopy(mixtral_moe)
|
|
||||||
|
|
||||||
experts = FusedExperts(
|
|
||||||
experts=mixtral_moe_copy.experts,
|
|
||||||
input_size=mixtral_moe_copy.ffn_dim,
|
|
||||||
hidden_size=mixtral_moe_copy.hidden_dim,
|
|
||||||
num_experts=mixtral_moe_copy.num_experts,
|
|
||||||
top_k=mixtral_moe_copy.top_k,
|
|
||||||
activation=mixtral_moe_copy.experts[0].act_fn
|
|
||||||
)
|
|
||||||
sparse_moe = SparseMoeBlock(
|
sparse_moe = SparseMoeBlock(
|
||||||
experts,
|
experts=mixtral_moe.experts,
|
||||||
hidden_dim=config.hidden_size,
|
hidden_dim=config.hidden_size,
|
||||||
ffn_dim=config.intermediate_size,
|
ffn_dim=config.intermediate_size,
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
|
|||||||
Reference in New Issue
Block a user