fix: make aux-free mixtral adapter GPU-safe

This commit is contained in:
lhl
2025-11-11 17:00:37 +00:00
committed by Wing Lian
parent ad0c825bcb
commit 966a4555db

View File

@@ -164,9 +164,11 @@ class MixtralAdapter(BaseMoEAdapter):
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
for expert_idx_tensor in expert_hit:
expert_idx = int(expert_idx_tensor.squeeze().item())
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
mask = expert_mask[expert_idx].squeeze(0)
idx, top_x = torch.where(mask)
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))