Refactor names, bugfixes
This commit is contained in:
@@ -15,8 +15,8 @@ class FusedExperts(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experts=None,
|
experts=None,
|
||||||
input_size=128,
|
hidden_dim=128,
|
||||||
hidden_size=512,
|
ffn_dim=512,
|
||||||
num_experts=8,
|
num_experts=8,
|
||||||
top_k=2,
|
top_k=2,
|
||||||
activation=nn.SiLU(),
|
activation=nn.SiLU(),
|
||||||
@@ -28,37 +28,39 @@ class FusedExperts(nn.Module):
|
|||||||
super(FusedExperts, self).__init__()
|
super(FusedExperts, self).__init__()
|
||||||
|
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.input_size = input_size
|
self.hidden_dim = hidden_dim
|
||||||
self.hidden_size = hidden_size
|
self.ffn_dim = ffn_dim
|
||||||
self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size)
|
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
|
||||||
self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
|
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
|
||||||
self.top_k = min(top_k, self.num_experts)
|
self.top_k = min(top_k, self.num_experts)
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
|
||||||
# parallelize all w1 and w3 computation by concat + stack
|
# parallelize all w1 and w3 computation by concat + stack
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.experts.weight.data = torch.stack(
|
torch.stack(
|
||||||
[
|
[
|
||||||
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=1)
|
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
||||||
for i in range(len(experts))
|
for i in range(len(experts))
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
|
out=self.experts.weight.data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# parallelize all w2 computation by stack
|
# parallelize all w2 computation by stack
|
||||||
self.output_experts.weight.data = torch.stack(
|
torch.stack(
|
||||||
[expert.w2.weight for expert in experts],
|
[expert.w2.weight for expert in experts],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
out=self.output_experts.weight.data,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor
|
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||||
):
|
):
|
||||||
x_shape = x.size()
|
x_shape = x.size()
|
||||||
x = x.view(-1, x_shape[-1])
|
x = x.view(-1, x_shape[-1])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
|
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
|
||||||
expert_idxs
|
selected_experts
|
||||||
)
|
)
|
||||||
padded_block_idxs, expert_offsets = ops.padded_block_indices(
|
padded_block_idxs, expert_offsets = ops.padded_block_indices(
|
||||||
sorted_expert_idxs, self.num_experts
|
sorted_expert_idxs, self.num_experts
|
||||||
@@ -82,7 +84,7 @@ class FusedExperts(nn.Module):
|
|||||||
padded_block_idxs,
|
padded_block_idxs,
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
grouped_in=True,
|
grouped_in=True,
|
||||||
gates=expert_p,
|
gates=routing_weights,
|
||||||
)
|
)
|
||||||
y = y.view(*x_shape[:-1], y.size(-1))
|
y = y.view(*x_shape[:-1], y.size(-1))
|
||||||
return y
|
return y
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ import torch.nn.functional as F
|
|||||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||||
|
|
||||||
class SparseMoeBlock(nn.Module):
|
class SparseMoeBlock(nn.Module):
|
||||||
def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k):
|
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.ffn_dim = ffn_dim
|
self.ffn_dim = ffn_dim
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
self.gate = gate
|
||||||
self.experts = FusedExperts(
|
self.experts = FusedExperts(
|
||||||
experts=experts,
|
experts=experts,
|
||||||
input_size=ffn_dim,
|
hidden_dim=hidden_dim,
|
||||||
hidden_size=hidden_dim,
|
ffn_dim=ffn_dim,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
activation=experts[0].act_fn
|
activation=experts[0].act_fn
|
||||||
|
|||||||
@@ -728,6 +728,7 @@ def load_model(
|
|||||||
if isinstance(module, MixtralSparseMoeBlock):
|
if isinstance(module, MixtralSparseMoeBlock):
|
||||||
smoe = SparseMoeBlock(
|
smoe = SparseMoeBlock(
|
||||||
experts=module.experts,
|
experts=module.experts,
|
||||||
|
gate=module.gate,
|
||||||
hidden_dim=module.hidden_dim,
|
hidden_dim=module.hidden_dim,
|
||||||
ffn_dim=module.ffn_dim,
|
ffn_dim=module.ffn_dim,
|
||||||
num_experts=module.num_experts,
|
num_experts=module.num_experts,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
torch.cuda.manual_seed(0)
|
torch.cuda.manual_seed(0)
|
||||||
torch.cuda.manual_seed_all(0)
|
torch.cuda.manual_seed_all(0)
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float32)
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
# Define the configuration for the MixtralSparseMoeBlock
|
# Define the configuration for the MixtralSparseMoeBlock
|
||||||
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
|
|||||||
|
|
||||||
sparse_moe = SparseMoeBlock(
|
sparse_moe = SparseMoeBlock(
|
||||||
experts=mixtral_moe.experts,
|
experts=mixtral_moe.experts,
|
||||||
|
gate=mixtral_moe.gate,
|
||||||
hidden_dim=config.hidden_size,
|
hidden_dim=config.hidden_size,
|
||||||
ffn_dim=config.intermediate_size,
|
ffn_dim=config.intermediate_size,
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
|
|||||||
|
|
||||||
# Run the forward pass with gradients for both models
|
# Run the forward pass with gradients for both models
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mixtral_output, _ = mixtral_moe(input_data)
|
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
|
||||||
sparse_output, _ = sparse_moe(input_data)
|
sparse_output, sparse_router_logits = sparse_moe(input_data)
|
||||||
|
|
||||||
# Compute the difference between the outputs and router logits
|
# Compute the difference between the outputs
|
||||||
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
||||||
|
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
||||||
|
|
||||||
|
print(output_diff, router_diff)
|
||||||
|
|
||||||
# Define the tolerance for the difference
|
# Define the tolerance for the difference
|
||||||
tolerance = 0.1
|
tolerance = 0.05
|
||||||
|
|
||||||
# # Check if the difference is within the tolerance
|
# # Check if the difference is within the tolerance
|
||||||
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
|
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
Reference in New Issue
Block a user