Initialize ParallelExperts weights on the device of the first expert
This commit is contained in:
@@ -123,9 +123,11 @@ def parallel_linear(inputs, expert_weights, k,
|
||||
return results
|
||||
|
||||
class ParallelExperts(nn.Module):
|
||||
def __init__(self, num_experts, input_size, output_size) -> None:
|
||||
def __init__(self, num_experts, input_size, output_size, device) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(num_experts, output_size, input_size, device=device)
|
||||
)
|
||||
self.num_experts = num_experts
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
|
||||
@@ -26,12 +26,14 @@ class FusedExperts(nn.Module):
|
||||
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
||||
"""
|
||||
super(FusedExperts, self).__init__()
|
||||
expert_device = experts[0].w1.weight.device
|
||||
output_expert_device = experts[0].w2.weight.device
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.hidden_dim = hidden_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
|
||||
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
|
||||
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device)
|
||||
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device)
|
||||
self.top_k = min(top_k, self.num_experts)
|
||||
self.activation = activation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user