Refactor names, bugfixes

This commit is contained in:
Casper Hansen
2024-03-15 12:39:11 +00:00
parent 1bc008e901
commit 26fc10df01
4 changed files with 31 additions and 23 deletions

View File

@@ -15,8 +15,8 @@ class FusedExperts(nn.Module):
def __init__( def __init__(
self, self,
experts=None, experts=None,
input_size=128, hidden_dim=128,
hidden_size=512, ffn_dim=512,
num_experts=8, num_experts=8,
top_k=2, top_k=2,
activation=nn.SiLU(), activation=nn.SiLU(),
@@ -28,37 +28,39 @@ class FusedExperts(nn.Module):
super(FusedExperts, self).__init__() super(FusedExperts, self).__init__()
self.num_experts = num_experts self.num_experts = num_experts
self.input_size = input_size self.hidden_dim = hidden_dim
self.hidden_size = hidden_size self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size) self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
self.output_experts = ParallelExperts(num_experts, hidden_size, input_size) self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
self.top_k = min(top_k, self.num_experts) self.top_k = min(top_k, self.num_experts)
self.activation = activation self.activation = activation
# parallelize all w1 and w3 computation by concat + stack # parallelize all w1 and w3 computation by concat + stack
with torch.no_grad(): with torch.no_grad():
self.experts.weight.data = torch.stack( torch.stack(
[ [
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=1) torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
for i in range(len(experts)) for i in range(len(experts))
], ],
dim=0, dim=0,
out=self.experts.weight.data,
) )
# parallelize all w2 computation by stack # parallelize all w2 computation by stack
self.output_experts.weight.data = torch.stack( torch.stack(
[expert.w2.weight for expert in experts], [expert.w2.weight for expert in experts],
dim=0, dim=0,
) out=self.output_experts.weight.data,
)
def forward( def forward(
self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
): ):
x_shape = x.size() x_shape = x.size()
x = x.view(-1, x_shape[-1]) x = x.view(-1, x_shape[-1])
with torch.no_grad(): with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
expert_idxs selected_experts
) )
padded_block_idxs, expert_offsets = ops.padded_block_indices( padded_block_idxs, expert_offsets = ops.padded_block_indices(
sorted_expert_idxs, self.num_experts sorted_expert_idxs, self.num_experts
@@ -82,7 +84,7 @@ class FusedExperts(nn.Module):
padded_block_idxs, padded_block_idxs,
expert_offsets, expert_offsets,
grouped_in=True, grouped_in=True,
gates=expert_p, gates=routing_weights,
) )
y = y.view(*x_shape[:-1], y.size(-1)) y = y.view(*x_shape[:-1], y.size(-1))
return y return y

View File

@@ -4,17 +4,17 @@ import torch.nn.functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts from axolotl.monkeypatch.moe.mlp import FusedExperts
class SparseMoeBlock(nn.Module): class SparseMoeBlock(nn.Module):
def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k): def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
super().__init__() super().__init__()
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim self.ffn_dim = ffn_dim
self.num_experts = num_experts self.num_experts = num_experts
self.top_k = top_k self.top_k = top_k
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.gate = gate
self.experts = FusedExperts( self.experts = FusedExperts(
experts=experts, experts=experts,
input_size=ffn_dim, hidden_dim=hidden_dim,
hidden_size=hidden_dim, ffn_dim=ffn_dim,
num_experts=num_experts, num_experts=num_experts,
top_k=top_k, top_k=top_k,
activation=experts[0].act_fn activation=experts[0].act_fn

View File

@@ -728,6 +728,7 @@ def load_model(
if isinstance(module, MixtralSparseMoeBlock): if isinstance(module, MixtralSparseMoeBlock):
smoe = SparseMoeBlock( smoe = SparseMoeBlock(
experts=module.experts, experts=module.experts,
gate=module.gate,
hidden_dim=module.hidden_dim, hidden_dim=module.hidden_dim,
ffn_dim=module.ffn_dim, ffn_dim=module.ffn_dim,
num_experts=module.num_experts, num_experts=module.num_experts,

View File

@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
torch.manual_seed(0) torch.manual_seed(0)
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float32)
torch.set_default_device("cuda") torch.set_default_device("cuda")
# Define the configuration for the MixtralSparseMoeBlock # Define the configuration for the MixtralSparseMoeBlock
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
sparse_moe = SparseMoeBlock( sparse_moe = SparseMoeBlock(
experts=mixtral_moe.experts, experts=mixtral_moe.experts,
gate=mixtral_moe.gate,
hidden_dim=config.hidden_size, hidden_dim=config.hidden_size,
ffn_dim=config.intermediate_size, ffn_dim=config.intermediate_size,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
# Run the forward pass with gradients for both models # Run the forward pass with gradients for both models
with torch.no_grad(): with torch.no_grad():
mixtral_output, _ = mixtral_moe(input_data) mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
sparse_output, _ = sparse_moe(input_data) sparse_output, sparse_router_logits = sparse_moe(input_data)
# Compute the difference between the outputs and router logits # Compute the difference between the outputs
output_diff = torch.abs(mixtral_output - sparse_output).mean().item() output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
print(output_diff, router_diff)
# Define the tolerance for the difference # Define the tolerance for the difference
tolerance = 0.1 tolerance = 0.05
# # Check if the difference is within the tolerance # # Check if the difference is within the tolerance
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"