Refactor names, bugfixes
@@ -15,8 +15,8 @@ class FusedExperts(nn.Module):
     def __init__(
         self,
         experts=None,
-        input_size=128,
-        hidden_size=512,
+        hidden_dim=128,
+        ffn_dim=512,
         num_experts=8,
         top_k=2,
         activation=nn.SiLU(),
@@ -28,37 +28,39 @@ class FusedExperts(nn.Module):
         super(FusedExperts, self).__init__()
 
         self.num_experts = num_experts
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size)
-        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
+        self.hidden_dim = hidden_dim
+        self.ffn_dim = ffn_dim
+        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
+        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
         self.top_k = min(top_k, self.num_experts)
         self.activation = activation
 
         # parallelize all w1 and w3 computation by concat + stack
         with torch.no_grad():
-            self.experts.weight.data = torch.stack(
+            torch.stack(
                 [
-                    torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=1)
+                    torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
                     for i in range(len(experts))
                 ],
                 dim=0,
+                out=self.experts.weight.data,
             )
 
             # parallelize all w2 computation by stack
-            self.output_experts.weight.data = torch.stack(
+            torch.stack(
                 [expert.w2.weight for expert in experts],
                 dim=0,
+                out=self.output_experts.weight.data,
             )
 
     def forward(
-        self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor
+        self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
     ):
         x_shape = x.size()
         x = x.view(-1, x_shape[-1])
         with torch.no_grad():
             sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
-                expert_idxs
+                selected_experts
             )
             padded_block_idxs, expert_offsets = ops.padded_block_indices(
                 sorted_expert_idxs, self.num_experts
@@ -82,7 +84,7 @@ class FusedExperts(nn.Module):
                 padded_block_idxs,
                 expert_offsets,
                 grouped_in=True,
-                gates=expert_p,
+                gates=routing_weights,
             )
         y = y.view(*x_shape[:-1], y.size(-1))
         return y
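Two of the fixes above are subtle. First, `nn.Linear` stores its weight as `(out_features, in_features)`, so each expert's `w1` and `w3` weights have shape `(ffn_dim, hidden_dim)` and must be concatenated along `dim=0` into a `(2 * ffn_dim, hidden_dim)` fused projection; the old `dim=1` produced `(ffn_dim, 2 * hidden_dim)`, which mixes input features instead of stacking the two projections. Second, writing through `out=self.experts.weight.data` fills the parameter's existing storage in place rather than rebinding `.data` to a freshly allocated tensor, so anything already holding a reference to that storage keeps seeing the fused weights. A minimal sketch of both points, with assumed toy sizes (`hidden_dim=4`, `ffn_dim=8`; not the values from this commit):

```python
import torch
import torch.nn as nn

hidden_dim, ffn_dim = 4, 8
w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # gate projection
w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # up projection

# nn.Linear weights are (out_features, in_features), so dim=0 stacks the
# two projections into one (2 * ffn_dim, hidden_dim) matrix ...
fused = torch.cat([w1.weight, w3.weight], dim=0).detach()
# ... while dim=1 would yield (ffn_dim, 2 * hidden_dim) and mix input features.
assert fused.shape == (2 * ffn_dim, hidden_dim)

x = torch.randn(3, hidden_dim)
gate, up = (x @ fused.t()).chunk(2, dim=-1)  # one matmul, both projections
assert torch.allclose(gate, w1(x)) and torch.allclose(up, w3(x))

# out= writes into existing storage instead of rebinding .data to a new tensor:
param = nn.Parameter(torch.empty(2, 2 * ffn_dim, hidden_dim))
storage_before = param.data.data_ptr()
with torch.no_grad():
    torch.stack([fused, fused], dim=0, out=param.data)
assert param.data.data_ptr() == storage_before
```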
@@ -4,17 +4,17 @@ import torch.nn.functional as F
 from axolotl.monkeypatch.moe.mlp import FusedExperts
 
 class SparseMoeBlock(nn.Module):
-    def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k):
+    def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.ffn_dim = ffn_dim
         self.num_experts = num_experts
         self.top_k = top_k
-        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        self.gate = gate
         self.experts = FusedExperts(
             experts=experts,
-            input_size=ffn_dim,
-            hidden_size=hidden_dim,
+            hidden_dim=hidden_dim,
+            ffn_dim=ffn_dim,
             num_experts=num_experts,
             top_k=top_k,
             activation=experts[0].act_fn
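The `gate` change matters for parity with the original model: the block previously created a fresh, randomly initialized router, so even with identical expert weights the top-k expert selection diverged from the checkpoint's. Passing the pretrained `gate` through keeps routing identical. A small sketch of the failure mode, with assumed toy sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim, num_experts, top_k = 16, 8, 2
x = torch.randn(4, hidden_dim)

pretrained_gate = nn.Linear(hidden_dim, num_experts, bias=False)
fresh_gate = nn.Linear(hidden_dim, num_experts, bias=False)  # what the old code built

# Different router weights pick different experts for the same tokens ...
_, pretrained_idx = pretrained_gate(x).topk(top_k, dim=-1)
_, fresh_idx = fresh_gate(x).topk(top_k, dim=-1)
print(torch.equal(pretrained_idx, fresh_idx))  # almost surely False

# ... while reusing the module keeps routing bit-identical.
reused_gate = pretrained_gate  # what the commit does
print(torch.equal(pretrained_gate(x), reused_gate(x)))  # True
```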
@@ -728,6 +728,7 @@ def load_model(
             if isinstance(module, MixtralSparseMoeBlock):
                 smoe = SparseMoeBlock(
                     experts=module.experts,
+                    gate=module.gate,
                     hidden_dim=module.hidden_dim,
                     ffn_dim=module.ffn_dim,
                     num_experts=module.num_experts,
@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     torch.cuda.manual_seed_all(0)
-    torch.set_default_dtype(torch.float16)
+    torch.set_default_dtype(torch.float32)
     torch.set_default_device("cuda")
 
     # Define the configuration for the MixtralSparseMoeBlock
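Running the comparison in float32 instead of float16 is what allows the tolerance to be tightened later in this diff (0.1 down to 0.05): half precision loses input accuracy before any arithmetic happens, so a loose tolerance was needed just to absorb rounding noise. A rough illustration of the gap, using assumed sizes and float64 as ground truth:

```python
import torch

torch.manual_seed(0)
a = torch.randn(512, 512, dtype=torch.float64)
b = torch.randn(512, 512, dtype=torch.float64)
ref = a @ b  # float64 ground truth

# Rounding the inputs through each precision, then multiplying in float32,
# isolates the representation error of the dtype itself.
err16 = ((a.half().float() @ b.half().float()) - ref).abs().mean().item()
err32 = ((a.float() @ b.float()) - ref).abs().mean().item()
print(f"float16 inputs: {err16:.2e}, float32 inputs: {err32:.2e}")  # roughly 1e-2 vs 1e-6
```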
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
 
     sparse_moe = SparseMoeBlock(
         experts=mixtral_moe.experts,
+        gate=mixtral_moe.gate,
         hidden_dim=config.hidden_size,
         ffn_dim=config.intermediate_size,
         num_experts=config.num_local_experts,
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
 
     # Run the forward pass without gradients for both models
     with torch.no_grad():
-        mixtral_output, _ = mixtral_moe(input_data)
-        sparse_output, _ = sparse_moe(input_data)
+        mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
+        sparse_output, sparse_router_logits = sparse_moe(input_data)
 
-    # Compute the difference between the outputs
+    # Compute the difference between the outputs and router logits
     output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
+    router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
+
+    print(output_diff, router_diff)
 
     # Define the tolerance for the difference
-    tolerance = 0.1
+    tolerance = 0.05
 
     # Check if the difference is within the tolerance
     assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
+    assert router_diff < tolerance, f"Router logits difference is {router_diff}, which is greater than the tolerance of {tolerance}"
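One thing the tightened test still does not catch: a mean absolute difference can stay under the tolerance while individual elements are wildly off. If that matters, `torch.testing.assert_close` compares elementwise with relative and absolute tolerances; a sketch of that alternative (not part of this commit):

```python
import torch

expected = torch.randn(8, 16)
actual = expected + 1e-4 * torch.randn(8, 16)

# Elementwise check: every entry must satisfy |actual - expected| <= atol + rtol * |expected|.
torch.testing.assert_close(actual, expected, rtol=0.0, atol=1e-3)
```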