commit 129db67705 (parent 38b890a36b)
Author: Dan Saunders
Date:   2025-09-17 16:24:29 -04:00


@@ -96,30 +96,27 @@ def moe_ffn_forward_grouped(
         hidden_states.device.type == "cuda" and routing_dtype == torch.float32
     )
+    x_base = hidden_states.view(-1, hdim)
     if use_mixed_router:
-        x_router = hidden_states.to(dtype=routing_dtype)
-        router_logits = gate_linear(x_router)
+        x_router = x_base.to(dtype=routing_dtype)
     else:
-        if hidden_states.dtype != routing_dtype:
-            hidden_states = hidden_states.to(dtype=routing_dtype)
-        x = hidden_states.view(-1, hdim)
-        router_logits = gate_linear(x)
+        x_router = x_base
+        if x_router.dtype != routing_dtype:
+            x_router = x_router.to(dtype=routing_dtype)
+    router_logits = gate_linear(x_router)
     if router_logits.dtype != routing_dtype:
         router_logits = router_logits.to(dtype=routing_dtype)
-    x = hidden_states.view(-1, hdim)
+    x = x_base
     # top-k routing executed in torch to avoid extra dependencies
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
     flat_idx = topk_idx.view(-1)
     E = _num_experts(experts_module)
     dev = hidden_states.device
-    dt: torch.dtype = hidden_states.dtype
     first = experts_module[0]
     is_mixtral = _is_mixtral_layout(first)
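
The top-k routing kept in this hunk is plain torch and can be exercised in isolation. A minimal sketch, assuming illustrative shapes (tokens, n_experts, and top_k below are made-up values, not the model's):

import torch
import torch.nn.functional as F

# Softmax in float32 for numerical stability, then renormalize the
# selected top-k weights so each token's expert weights sum to 1.
tokens, n_experts, top_k = 4, 8, 2
router_logits = torch.randn(tokens, n_experts, dtype=torch.float32)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
assert torch.allclose(topk_weight.sum(dim=-1), torch.ones(tokens))
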
@@ -157,6 +154,7 @@ def moe_ffn_forward_grouped(
         return nested_mod
     if is_mixtral:
+        dt: torch.dtype = first.w1.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w1")
             or experts_module._stacked_w1.device != dev
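
The added dt line matters under mixed-precision routing: hidden_states may have been upcast to float32 while the Mixtral expert weights stay in a lower precision, so deriving the stacking dtype from first.w1 rather than from the activations keeps the cached stacks in the experts' dtype. A minimal sketch, with an nn.Linear standing in for the w1 projection:

import torch
import torch.nn as nn

# Assumption for illustration: expert weights in bfloat16, activations
# upcast to float32 by the routing path. Reading dt from the weight
# (as the hunk does) picks bfloat16; the old hidden_states.dtype would not.
expert_w1 = nn.Linear(16, 32, dtype=torch.bfloat16)  # stands in for first.w1
hidden_states = torch.randn(4, 16, dtype=torch.float32)
dt = expert_w1.weight.dtype
assert dt == torch.bfloat16 and hidden_states.dtype != dt
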
@@ -187,6 +185,13 @@ def moe_ffn_forward_grouped(
         W13 = experts_module._stacked_w13
         W2 = experts_module._stacked_w2
     else:
+        sample_mod = _resolve_expert(0)
+        if hasattr(sample_mod, "gate_up_proj"):
+            dt = sample_mod.gate_up_proj.weight.dtype  # type: ignore[assignment]
+        elif hasattr(sample_mod, "up_proj"):
+            dt = sample_mod.up_proj.weight.dtype  # type: ignore[assignment]
+        else:
+            dt = sample_mod.down_proj.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w13")
             or experts_module._stacked_w13.device != dev
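
The non-Mixtral branch probes the expert layout attribute by attribute: some expert implementations fuse the gate and up projections into a single gate_up_proj, others keep up_proj and down_proj separate. A minimal sketch of that probe, where FusedExpert and SplitExpert are made-up stand-ins (not classes from this repo):

import torch
import torch.nn as nn

class FusedExpert(nn.Module):
    # fused gate+up projection, as in some MoE expert layouts
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(16, 64, dtype=torch.bfloat16)
        self.down_proj = nn.Linear(32, 16, dtype=torch.bfloat16)

class SplitExpert(nn.Module):
    # separate up/down projections
    def __init__(self):
        super().__init__()
        self.up_proj = nn.Linear(16, 32, dtype=torch.float16)
        self.down_proj = nn.Linear(32, 16, dtype=torch.float16)

def weight_dtype(mod: nn.Module) -> torch.dtype:
    # same fallback chain as the hunk: gate_up_proj, then up_proj, then down_proj
    if hasattr(mod, "gate_up_proj"):
        return mod.gate_up_proj.weight.dtype
    elif hasattr(mod, "up_proj"):
        return mod.up_proj.weight.dtype
    return mod.down_proj.weight.dtype

assert weight_dtype(FusedExpert()) == torch.bfloat16
assert weight_dtype(SplitExpert()) == torch.float16
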