Initialize ParallelExperts on device of first expert

This commit is contained in:
Casper
2024-03-17 19:51:31 +01:00
parent 2ea75b4160
commit 884d81331e
2 changed files with 8 additions and 4 deletions

View File

@@ -123,9 +123,11 @@ def parallel_linear(inputs, expert_weights, k,
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size) -> None:
def __init__(self, num_experts, input_size, output_size, device) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
self.weight = nn.Parameter(
torch.empty(num_experts, output_size, input_size, device=device)
)
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size

View File

@@ -26,12 +26,14 @@ class FusedExperts(nn.Module):
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
"""
super(FusedExperts, self).__init__()
expert_device = experts[0].w1.weight.device
output_expert_device = experts[0].w2.weight.device
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device)
self.top_k = min(top_k, self.num_experts)
self.activation = activation