diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py index 35b1dccda..fb75f2740 100644 --- a/src/axolotl/monkeypatch/moe/mlp.py +++ b/src/axolotl/monkeypatch/moe/mlp.py @@ -15,8 +15,8 @@ class FusedExperts(nn.Module): def __init__( self, experts=None, - input_size=128, - hidden_size=512, + hidden_dim=128, + ffn_dim=512, num_experts=8, top_k=2, activation=nn.SiLU(), @@ -28,37 +28,39 @@ class FusedExperts(nn.Module): super(FusedExperts, self).__init__() self.num_experts = num_experts - self.input_size = input_size - self.hidden_size = hidden_size - self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size) - self.output_experts = ParallelExperts(num_experts, hidden_size, input_size) + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim) + self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim) self.top_k = min(top_k, self.num_experts) self.activation = activation # parallelize all w1 and w3 computation by concat + stack with torch.no_grad(): - self.experts.weight.data = torch.stack( + torch.stack( [ - torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=1) + torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0) for i in range(len(experts)) ], dim=0, + out=self.experts.weight.data, ) # parallelize all w2 computation by stack - self.output_experts.weight.data = torch.stack( + torch.stack( [expert.w2.weight for expert in experts], dim=0, - ) + out=self.output_experts.weight.data, + ) def forward( - self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor + self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor ): x_shape = x.size() x = x.view(-1, x_shape[-1]) with torch.no_grad(): sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( - expert_idxs + selected_experts ) padded_block_idxs, expert_offsets = ops.padded_block_indices( sorted_expert_idxs, self.num_experts 
@@ -82,7 +84,7 @@ class FusedExperts(nn.Module): padded_block_idxs, expert_offsets, grouped_in=True, - gates=expert_p, + gates=routing_weights, ) y = y.view(*x_shape[:-1], y.size(-1)) return y diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py index 4e4ffec31..f2a506883 100644 --- a/src/axolotl/monkeypatch/moe/moe.py +++ b/src/axolotl/monkeypatch/moe/moe.py @@ -4,17 +4,17 @@ import torch.nn.functional as F from axolotl.monkeypatch.moe.mlp import FusedExperts class SparseMoeBlock(nn.Module): - def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k): + def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k): super().__init__() self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim self.num_experts = num_experts self.top_k = top_k - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.gate = gate self.experts = FusedExperts( experts=experts, - input_size=ffn_dim, - hidden_size=hidden_dim, + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, num_experts=num_experts, top_k=top_k, activation=experts[0].act_fn diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a992357e9..139d8555d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -728,6 +728,7 @@ def load_model( if isinstance(module, MixtralSparseMoeBlock): smoe = SparseMoeBlock( experts=module.experts, + gate=module.gate, hidden_dim=module.hidden_dim, ffn_dim=module.ffn_dim, num_experts=module.num_experts, diff --git a/tests/monkeypatch/test_moe.py b/tests/monkeypatch/test_moe.py index 503de3153..7fbc06462 100644 --- a/tests/monkeypatch/test_moe.py +++ b/tests/monkeypatch/test_moe.py @@ -10,7 +10,7 @@ def test_fused_mixtral_moe(): torch.manual_seed(0) torch.cuda.manual_seed(0) torch.cuda.manual_seed_all(0) - torch.set_default_dtype(torch.float16) + torch.set_default_dtype(torch.float32) torch.set_default_device("cuda") # Define the configuration for the MixtralSparseMoeBlock @@ -26,6 +26,7 @@ 
def test_fused_mixtral_moe(): sparse_moe = SparseMoeBlock( experts=mixtral_moe.experts, + gate=mixtral_moe.gate, hidden_dim=config.hidden_size, ffn_dim=config.intermediate_size, num_experts=config.num_local_experts, @@ -39,14 +40,18 @@ def test_fused_mixtral_moe(): # Run the forward pass with gradients for both models with torch.no_grad(): - mixtral_output, _ = mixtral_moe(input_data) - sparse_output, _ = sparse_moe(input_data) + mixtral_output, mixtral_router_logits = mixtral_moe(input_data) + sparse_output, sparse_router_logits = sparse_moe(input_data) - # Compute the difference between the outputs and router logits + # Compute the difference between the outputs output_diff = torch.abs(mixtral_output - sparse_output).mean().item() + router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item() + + print(output_diff, router_diff) # Define the tolerance for the difference - tolerance = 0.1 + tolerance = 0.05 # # Check if the difference is within the tolerance - assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" \ No newline at end of file + assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" + assert router_diff < tolerance, f"Router logits difference is {router_diff}, which is greater than the tolerance of {tolerance}" \ No newline at end of file