From 884d81331ef7bcc589311539622130c770d05c19 Mon Sep 17 00:00:00 2001 From: Casper Date: Sun, 17 Mar 2024 19:51:31 +0100 Subject: [PATCH] Initialize ParallelExperts on device of first expert --- src/axolotl/monkeypatch/moe/linear.py | 6 ++++-- src/axolotl/monkeypatch/moe/mlp.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/axolotl/monkeypatch/moe/linear.py b/src/axolotl/monkeypatch/moe/linear.py index 1d2c3e101..5826132da 100644 --- a/src/axolotl/monkeypatch/moe/linear.py +++ b/src/axolotl/monkeypatch/moe/linear.py @@ -123,9 +123,11 @@ def parallel_linear(inputs, expert_weights, k, return results class ParallelExperts(nn.Module): - def __init__(self, num_experts, input_size, output_size) -> None: + def __init__(self, num_experts, input_size, output_size, device) -> None: super().__init__() - self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.weight = nn.Parameter( + torch.empty(num_experts, output_size, input_size, device=device) + ) self.num_experts = num_experts self.input_size = input_size self.output_size = output_size diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py index fb75f2740..f1a1328a7 100644 --- a/src/axolotl/monkeypatch/moe/mlp.py +++ b/src/axolotl/monkeypatch/moe/mlp.py @@ -26,12 +26,14 @@ class FusedExperts(nn.Module): MLP of type Gated-Linear Unit, typically with a SiLU activation function. """ super(FusedExperts, self).__init__() + expert_device = experts[0].w1.weight.device + output_expert_device = experts[0].w2.weight.device self.num_experts = num_experts self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim - self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim) - self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim) + self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device) + self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device) self.top_k = min(top_k, self.num_experts) self.activation = activation