Initialize ParallelExperts weights on the device of the first expert
This commit is contained in:
@@ -123,9 +123,11 @@ def parallel_linear(inputs, expert_weights, k,
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
class ParallelExperts(nn.Module):
|
class ParallelExperts(nn.Module):
|
||||||
def __init__(self, num_experts, input_size, output_size) -> None:
|
def __init__(self, num_experts, input_size, output_size, device) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
self.weight = nn.Parameter(
|
||||||
|
torch.empty(num_experts, output_size, input_size, device=device)
|
||||||
|
)
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
|
|||||||
@@ -26,12 +26,14 @@ class FusedExperts(nn.Module):
|
|||||||
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
||||||
"""
|
"""
|
||||||
super(FusedExperts, self).__init__()
|
super(FusedExperts, self).__init__()
|
||||||
|
expert_device = experts[0].w1.weight.device
|
||||||
|
output_expert_device = experts[0].w2.weight.device
|
||||||
|
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.ffn_dim = ffn_dim
|
self.ffn_dim = ffn_dim
|
||||||
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
|
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device)
|
||||||
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
|
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device)
|
||||||
self.top_k = min(top_k, self.num_experts)
|
self.top_k = min(top_k, self.num_experts)
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user