diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 72efde43c..2e947c83f 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -96,30 +96,27 @@ def moe_ffn_forward_grouped(
         hidden_states.device.type == "cuda" and routing_dtype == torch.float32
     )
+    x_base = hidden_states.view(-1, hdim)
     if use_mixed_router:
-        x_router = hidden_states.to(dtype=routing_dtype)
-        router_logits = gate_linear(x_router)
+        x_router = x_base.to(dtype=routing_dtype)
     else:
-        if hidden_states.dtype != routing_dtype:
-            hidden_states = hidden_states.to(dtype=routing_dtype)
-        x = hidden_states.view(-1, hdim)
-        router_logits = gate_linear(x)
+        x_router = x_base
+        if x_router.dtype != routing_dtype:
+            x_router = x_router.to(dtype=routing_dtype)
+    router_logits = gate_linear(x_router)
     if router_logits.dtype != routing_dtype:
         router_logits = router_logits.to(dtype=routing_dtype)
 
-    x = hidden_states.view(-1, hdim)
+    x = x_base
 
     # top-k routing executed in torch to avoid extra dependencies
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
-    flat_idx = topk_idx.view(-1)
-
     E = _num_experts(experts_module)
     dev = hidden_states.device
-    dt: torch.dtype = hidden_states.dtype
 
     first = experts_module[0]
     is_mixtral = _is_mixtral_layout(first)
@@ -157,6 +154,7 @@ def moe_ffn_forward_grouped(
         return nested_mod
 
     if is_mixtral:
+        dt: torch.dtype = first.w1.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w1")
             or experts_module._stacked_w1.device != dev
@@ -187,6 +185,13 @@ def moe_ffn_forward_grouped(
         W13 = experts_module._stacked_w13
         W2 = experts_module._stacked_w2
     else:
+        sample_mod = _resolve_expert(0)
+        if hasattr(sample_mod, "gate_up_proj"):
+            dt = sample_mod.gate_up_proj.weight.dtype  # type: ignore[assignment]
+        elif hasattr(sample_mod, "up_proj"):
+            dt = sample_mod.up_proj.weight.dtype  # type: ignore[assignment]
+        else:
+            dt = sample_mod.down_proj.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w13")
             or experts_module._stacked_w13.device != dev
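
Note: for reviewers, a minimal standalone sketch of the top-k routing sequence the diff keeps in plain torch (softmax over experts, top-k selection without sorting, per-token renormalization). The shapes and values below are illustrative assumptions, not taken from axolotl:

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes for illustration only.
    num_tokens, num_experts, top_k = 4, 8, 2
    router_logits = torch.randn(num_tokens, num_experts)

    # Softmax over the expert dimension in float32 for stable routing probabilities.
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

    # Pick the top-k experts per token; sorted=False skips an unneeded sort.
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)

    # Renormalize so the selected experts' weights sum to 1 per token.
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

    print(topk_idx.shape, topk_weight.sum(dim=-1))  # torch.Size([4, 2]), all ones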