From 301cc4c006be0a41c4b446d38f69913ffd53f0ae Mon Sep 17 00:00:00 2001
From: Eric Hartford
Date: Fri, 15 Mar 2024 13:16:06 -0700
Subject: [PATCH] implement post training

---
 src/axolotl/monkeypatch/moe/moe.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py
index f2a506883..ee3d9bb28 100644
--- a/src/axolotl/monkeypatch/moe/moe.py
+++ b/src/axolotl/monkeypatch/moe/moe.py
@@ -21,14 +21,38 @@ class SparseMoeBlock(nn.Module):
         )
 
     def _post_training(self, model, name):
-        # get original weights back: reverse the concat + stack in the fused experts
-        w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
+        # Get the original weights back: reverse the concat + stack in the fused
+        # experts. experts.weight is (num_experts, 2 * ffn_dim, hidden_dim), so
+        # unbind per expert, then chunk each matrix back into w1 and w3.
+        w1s, w3s = zip(*(w.chunk(2, dim=0) for w in self.experts.experts.weight.unbind(dim=0)))
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # TODO: recreate MoE class with original weights
+        # Recreate the per-expert modules with the original weights. The fused
+        # experts compute w2(act(w1(x)) * w3(x)), a gated MLP, so the recovered
+        # expert needs a gated forward pass rather than a plain nn.Sequential.
+        class UnfusedExpert(nn.Module):
+            def __init__(self, hidden_dim, ffn_dim, activation):
+                super().__init__()
+                self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
+                self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
+                self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
+                self.activation = activation
+
+            def forward(self, x):
+                return self.w2(self.activation(self.w1(x)) * self.w3(x))
+
         experts = []
         for i in range(self.num_experts):
-            pass
+            expert = UnfusedExpert(self.hidden_dim, self.ffn_dim, self.experts.activation)
+            expert.w1.weight.data = w1s[i]
+            expert.w3.weight.data = w3s[i]
+            expert.w2.weight.data = w2s[i]
+            experts.append(expert)
+
+        # Replace the fused experts with the recreated MoE module
+        moe = nn.ModuleList(experts)
+        setattr(model, name.replace("experts", "moe"), moe)
+        delattr(model, name)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
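
A note for reviewers (not part of the patch): _post_training mutates the module tree via setattr/delattr on whatever is passed as model, so a caller has to hand each fused block the module that owns it plus the attribute name it hangs off. The driver below is a hedged sketch of that call pattern; unfuse_for_save and the attribute-walking convention are assumptions of this note, not APIs introduced by the patch.

import torch.nn as nn

def unfuse_for_save(model: nn.Module) -> None:
    """Hypothetical driver: unfuse every SparseMoeBlock in place before saving."""
    # Snapshot targets first: _post_training rewrites attributes via
    # setattr/delattr, and mutating the tree mid-iteration is unsafe.
    targets = [
        (parent, attr_name)
        for parent in model.modules()
        for attr_name, child in parent.named_children()
        if type(child).__name__ == "SparseMoeBlock"
    ]
    for parent, attr_name in targets:
        getattr(parent, attr_name)._post_training(parent, attr_name)

The patch itself does not pin down whether name is the block's own attribute name or the fused experts' name (the name.replace("experts", "moe") call suggests it contains "experts"), so the sketch above is only one plausible contract.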