From 966a4555dbebb224b2d424e827bb57890d6b20a1 Mon Sep 17 00:00:00 2001
From: lhl
Date: Tue, 11 Nov 2025 17:00:37 +0000
Subject: [PATCH] fix: make aux-free mixtral adapter GPU-safe

---
 src/axolotl/integrations/aux_free_router/adapters.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py
index ea49c6936..ac9ed1851 100644
--- a/src/axolotl/integrations/aux_free_router/adapters.py
+++ b/src/axolotl/integrations/aux_free_router/adapters.py
@@ -164,9 +164,11 @@ class MixtralAdapter(BaseMoEAdapter):
         expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
+        for expert_idx_tensor in expert_hit:
+            expert_idx = int(expert_idx_tensor.squeeze().item())
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            mask = expert_mask[expert_idx].squeeze(0)
+            idx, top_x = torch.where(mask)
             current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
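
Note (illustrative sketch, not part of the patch): iterating over the result of
nonzero() yields shape-[1] tensors, and indexing an nn.ModuleList or a Python
container with a CUDA tensor triggers an implicit .item() (a host-side sync) on
every lookup, which can also break under torch.compile or CUDA graph capture.
The patch instead converts each hit to a plain Python int once, explicitly.
Below is a minimal standalone reproduction of the patched dispatch loop; all
shapes and names here are toy values invented for this note, not repo code.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy shapes; names below are invented for this note.
num_experts, top_k, tokens, hidden = 4, 2, 8, 16
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
flat_states = torch.randn(tokens, hidden)
router_logits = torch.randn(tokens, num_experts)
routing_weights, selected_experts = torch.topk(router_logits.softmax(dim=-1), top_k)

# [num_experts, top_k, tokens]: mirrors the shapes in the patched loop.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

final = torch.zeros_like(flat_states)
for expert_idx_tensor in expert_hit:                       # each hit is a shape-[1] tensor
    expert_idx = int(expert_idx_tensor.squeeze().item())   # one explicit host sync per hit
    mask = expert_mask[expert_idx]                         # int index -> [top_k, tokens], no extra leading dim
    idx, top_x = torch.where(mask)                         # idx: top-k slot, top_x: token position
    current = flat_states[top_x]
    final.index_add_(0, top_x, experts[expert_idx](current) * routing_weights[top_x, idx, None])

Two details worth noting: with an int index, expert_mask[expert_idx] already has
shape [top_k, tokens], so the .squeeze(0) retained in the patch is a harmless
no-op for Mixtral's top_k=2; and the sketch simplifies the patch's
flat_states[None, top_x].reshape(-1, hidden_dim) to the equivalent
flat_states[top_x].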