implement post training

Eric Hartford
2024-03-15 13:16:06 -07:00
parent 035e680631
commit 301cc4c006


@@ -21,14 +21,28 @@ class SparseMoeBlock(nn.Module):
        )

    def _post_training(self, model, name):
        # Get original weights back: reverse the concat + stack in the fused experts.
        # Each fused expert weight is [2 * ffn_dim, hidden_dim]; its top half is w1
        # and its bottom half is w3.
        w1s, w3s = zip(
            *(torch.chunk(w, 2, dim=0) for w in torch.unbind(self.experts.experts.weight, dim=0))
        )
        w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
        # Recreate the MoE class with original weights
        experts = []
        for i in range(self.num_experts):
            # Note: this layout assumes self.experts.activation is GLU-style,
            # mapping the 2 * ffn_dim projection down to ffn_dim before the
            # output projection.
            expert = nn.Sequential(
                nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False),
                self.experts.activation,
                nn.Linear(self.ffn_dim, self.hidden_dim, bias=False),
            )
            expert[0].weight.data = torch.cat([w1s[i], w3s[i]], dim=0)
            expert[2].weight.data = w2s[i]
            experts.append(expert)
        # Create a new MoE module with the recreated experts
        moe = nn.ModuleList(experts)
        # Replace the fused experts with the recreated MoE module
        setattr(model, name.replace("experts", "moe"), moe)
        delattr(model, name)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
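
For reference, a minimal sketch of how this hook might be driven once training is done: walk the model, find each fused SparseMoeBlock, and let it splice the recreated per-expert modules back in before saving. The driver loop below and the exact attribute path that _post_training expects for name are illustrative assumptions, not part of this commit.

# Hypothetical post-training pass (names are illustrative, not from this commit).
# Collect the blocks first, since _post_training mutates the module tree.
blocks = [
    (name, module)
    for name, module in model.named_modules()
    if isinstance(module, SparseMoeBlock)
]
for name, block in blocks:
    block._post_training(model, name)
# model can now be saved with plain, unfused per-expert weights.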