Refactor names, bugfixes

This commit is contained in:
Casper Hansen
2024-03-15 12:39:11 +00:00
parent 1bc008e901
commit 26fc10df01
4 changed files with 31 additions and 23 deletions

View File

@@ -15,8 +15,8 @@ class FusedExperts(nn.Module):
def __init__(
self,
experts=None,
input_size=128,
hidden_size=512,
hidden_dim=128,
ffn_dim=512,
num_experts=8,
top_k=2,
activation=nn.SiLU(),
@@ -28,37 +28,39 @@ class FusedExperts(nn.Module):
super(FusedExperts, self).__init__()
self.num_experts = num_experts
self.input_size = input_size
self.hidden_size = hidden_size
self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size)
self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
self.top_k = min(top_k, self.num_experts)
self.activation = activation
# parallelize all w1 and w3 computation by concat + stack
with torch.no_grad():
self.experts.weight.data = torch.stack(
torch.stack(
[
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=1)
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
for i in range(len(experts))
],
dim=0,
out=self.experts.weight.data,
)
# parallelize all w2 computation by stack
self.output_experts.weight.data = torch.stack(
torch.stack(
[expert.w2.weight for expert in experts],
dim=0,
)
out=self.output_experts.weight.data,
)
def forward(
self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
):
x_shape = x.size()
x = x.view(-1, x_shape[-1])
with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
expert_idxs
selected_experts
)
padded_block_idxs, expert_offsets = ops.padded_block_indices(
sorted_expert_idxs, self.num_experts
@@ -82,7 +84,7 @@ class FusedExperts(nn.Module):
padded_block_idxs,
expert_offsets,
grouped_in=True,
gates=expert_p,
gates=routing_weights,
)
y = y.view(*x_shape[:-1], y.size(-1))
return y

View File

@@ -4,17 +4,17 @@ import torch.nn.functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts
class SparseMoeBlock(nn.Module):
def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k):
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
super().__init__()
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.gate = gate
self.experts = FusedExperts(
experts=experts,
input_size=ffn_dim,
hidden_size=hidden_dim,
hidden_dim=hidden_dim,
ffn_dim=ffn_dim,
num_experts=num_experts,
top_k=top_k,
activation=experts[0].act_fn

View File

@@ -728,6 +728,7 @@ def load_model(
if isinstance(module, MixtralSparseMoeBlock):
smoe = SparseMoeBlock(
experts=module.experts,
gate=module.gate,
hidden_dim=module.hidden_dim,
ffn_dim=module.ffn_dim,
num_experts=module.num_experts,

View File

@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(torch.float32)
torch.set_default_device("cuda")
# Define the configuration for the MixtralSparseMoeBlock
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
sparse_moe = SparseMoeBlock(
experts=mixtral_moe.experts,
gate=mixtral_moe.gate,
hidden_dim=config.hidden_size,
ffn_dim=config.intermediate_size,
num_experts=config.num_local_experts,
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
# Run the forward pass with gradients for both models
with torch.no_grad():
mixtral_output, _ = mixtral_moe(input_data)
sparse_output, _ = sparse_moe(input_data)
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
sparse_output, sparse_router_logits = sparse_moe(input_data)
# Compute the difference between the outputs and router logits
# Compute the difference between the outputs
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
print(output_diff, router_diff)
# Define the tolerance for the difference
tolerance = 0.1
tolerance = 0.05
# # Check if the difference is within the tolerance
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"