fix: make aux-free mixtral adapter GPU-safe
@@ -164,9 +164,11 @@ class MixtralAdapter(BaseMoEAdapter):
         expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
+        for expert_idx_tensor in expert_hit:
+            expert_idx = int(expert_idx_tensor.squeeze().item())
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            mask = expert_mask[expert_idx].squeeze(0)
+            idx, top_x = torch.where(mask)
             current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
 
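For context, here is the dispatch loop this hunk lands on, as a minimal runnable sketch. The expert modules, shapes, and routing setup (plain nn.Linear experts, top_k=2, random inputs) are illustrative assumptions, not the adapter's real configuration; only the loop body mirrors the committed code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    num_experts, top_k, num_tokens, hidden_dim = 4, 2, 8, 16
    experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

    # Assumed routing setup: softmax over random logits, then top-k selection.
    flat_states = torch.randn(num_tokens, hidden_dim)
    routing_weights, selected_experts = torch.topk(
        torch.randn(num_tokens, num_experts).softmax(dim=-1), top_k, dim=-1
    )
    final_hidden_states = torch.zeros(num_tokens, hidden_dim)

    # (num_experts, top_k, num_tokens): which tokens hit which expert in which slot.
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

    for expert_idx_tensor in expert_hit:
        # nonzero() yields shape-(1,) tensors; .item() turns each into a host-side
        # Python int, so the ModuleList and mask indexing below never go through
        # a live GPU tensor.
        expert_idx = int(expert_idx_tensor.squeeze().item())
        expert_layer = experts[expert_idx]
        mask = expert_mask[expert_idx].squeeze(0)   # (top_k, num_tokens)
        idx, top_x = torch.where(mask)              # slot indices, token indices
        current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))

Note that .item() implies a device-to-host sync per hit expert; the fix trades that small sync for indexing that does not depend on a one-element GPU tensor.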