diff --git a/src/axolotl/models/mixtral/modeling_moe_mistral.py b/src/axolotl/models/mixtral/modeling_moe_mistral.py index 5df91e3f4..6f1fb7a4a 100644 --- a/src/axolotl/models/mixtral/modeling_moe_mistral.py +++ b/src/axolotl/models/mixtral/modeling_moe_mistral.py @@ -215,23 +215,22 @@ class MoE(nn.Module): ): super().__init__() self.config = config - num_experts = config.num_experts - self.experts = nn.ModuleList([FeedForward(config) for i in range(num_experts)]) - self.gate = nn.Linear(config.hidden_size, num_experts, bias=False) - self.num_experts_per_token = config.num_experts_per_token + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [FeedForward(config) for i in range(config.num_experts)] + ) def forward(self, x): orig_shape = x.shape x = x.view(-1, x.shape[-1]) - scores = self.gate(x) + scores = self.gate(x).softmax(dim=-1) expert_weights, expert_indices = torch.topk( - scores, self.num_experts_per_token, dim=-1 + scores, self.config.num_experts_per_token, dim=-1 ) - expert_weights = expert_weights.softmax(dim=-1) flat_expert_indices = expert_indices.view(-1) - x = x.repeat_interleave(self.num_experts_per_token, dim=0) + x = x.repeat_interleave(self.config.num_experts_per_token, dim=0) y = torch.empty_like(x) for i, expert in enumerate(self.experts): y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])