code review feedback

This commit is contained in:
Eric Hartford
2024-03-15 14:10:22 -07:00
parent 301cc4c006
commit 9c221a6761

View File

@@ -25,24 +25,33 @@ class SparseMoeBlock(nn.Module):
w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1) w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
w2s = torch.unbind(self.experts.output_experts.weight, dim=0) w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
# Recreate the MoE class with original weights # Recreate the structure of the original MixtralSparseMoeBlock
experts = [] original_moe = nn.Module()
for i in range(self.num_experts): original_moe.hidden_dim = self.hidden_dim
expert = nn.Sequential( original_moe.ffn_dim = self.ffn_dim
nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False), original_moe.num_experts = self.num_experts
self.experts.activation, original_moe.top_k = self.top_k
nn.Linear(self.ffn_dim, self.hidden_dim, bias=False),
)
expert[0].weight.data = torch.cat([w1s[i], w3s[i]], dim=0)
expert[2].weight.data = w2s[i]
experts.append(expert)
# Create a new MoE module with the recreated experts # Recreate the gating module
moe = nn.ModuleList(experts) original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
original_moe.gate.weight.data = self.gate.weight.data
# Replace the fused experts with the recreated MoE module # Recreate the experts as a ModuleList
setattr(model, name.replace("experts", "moe"), moe) original_moe.experts = nn.ModuleList()
delattr(model, name) for expert_idx in range(self.num_experts):
expert = nn.Module()
expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
expert.act_fn = self.experts.activation
expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
expert.w2.weight.data = w2s[expert_idx]
original_moe.experts.append(expert)
# Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
setattr(model, name, original_moe)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape