Compare commits
2 Commits
scatter_mo
...
scatter_mo
| Author | SHA1 | Date | |
|---|---|---|---|
| | 9c221a6761 | | |
| | 301cc4c006 | | |
```diff
@@ -21,14 +21,37 @@ class SparseMoeBlock(nn.Module):
         )
 
     def _post_training(self, model, name):
-        # get original weights back: reverse the concat + stack in the fused experts
+        # Get original weights back: reverse the concat + stack in the fused experts
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # TODO: recreate MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            pass
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
+
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
+
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+
+            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
+            expert.w2.weight.data = w2s[expert_idx]
+
+            original_moe.experts.append(expert)
+
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
```
Reference in New Issue
Block a user