fix: derive router input from a shared x_base view; take MoE compute dtype from expert weights
@@ -96,30 +96,27 @@ def moe_ffn_forward_grouped(
         hidden_states.device.type == "cuda" and routing_dtype == torch.float32
     )
 
+    x_base = hidden_states.view(-1, hdim)
     if use_mixed_router:
-        x_router = hidden_states.to(dtype=routing_dtype)
-        router_logits = gate_linear(x_router)
+        x_router = x_base.to(dtype=routing_dtype)
     else:
-        if hidden_states.dtype != routing_dtype:
-            hidden_states = hidden_states.to(dtype=routing_dtype)
-        x = hidden_states.view(-1, hdim)
-        router_logits = gate_linear(x)
+        x_router = x_base
+        if x_router.dtype != routing_dtype:
+            x_router = x_router.to(dtype=routing_dtype)
 
-    if router_logits.dtype != routing_dtype:
-        router_logits = router_logits.to(dtype=routing_dtype)
+    router_logits = gate_linear(x_router)
 
-    x = hidden_states.view(-1, hdim)
+    x = x_base
 
     # top-k routing executed in torch to avoid extra dependencies
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
 
     flat_idx = topk_idx.view(-1)
 
     E = _num_experts(experts_module)
     dev = hidden_states.device
-    dt: torch.dtype = hidden_states.dtype
     first = experts_module[0]
 
     is_mixtral = _is_mixtral_layout(first)
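For readers skimming the hunk, here is a minimal standalone sketch of the new routing path, assuming the surrounding code matches what the hunk shows. gate_linear, routing_dtype, top_k, and hdim mirror names from the diff; the wrapper function itself is hypothetical, and the use_mixed_router expression keeps only the clause visible above (the real condition may carry more terms). The net effect of the hunk: one flattened x_base view feeds both the router and the experts, hidden_states is no longer reassigned in place, and the post-hoc router_logits dtype coercion is dropped.

import torch
import torch.nn.functional as F


def route_tokens(hidden_states, gate_linear, routing_dtype=torch.float32, top_k=2):
    # Hypothetical wrapper; names mirror the hunk above.
    hdim = hidden_states.shape[-1]
    x_base = hidden_states.view(-1, hdim)  # single flattened view, computed once

    # Only the clause visible in the hunk; the original may test more.
    use_mixed_router = (
        hidden_states.device.type == "cuda" and routing_dtype == torch.float32
    )
    if use_mixed_router:
        x_router = x_base.to(dtype=routing_dtype)
    else:
        x_router = x_base
        if x_router.dtype != routing_dtype:
            x_router = x_router.to(dtype=routing_dtype)

    router_logits = gate_linear(x_router)

    # top-k routing executed in torch to avoid extra dependencies
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return x_base, topk_weight, topk_idx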
@@ -157,6 +154,7 @@ def moe_ffn_forward_grouped(
         return nested_mod
 
     if is_mixtral:
+        dt: torch.dtype = first.w1.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w1")
             or experts_module._stacked_w1.device != dev
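The hasattr/device guard above implies a lazily built cache of stacked per-expert weights. The stacking itself sits outside the visible hunks, so the following is only a sketch of what such a cache conventionally looks like for the Mixtral w1/w2/w3 layout; the torch.stack calls, the dtype clause in the guard, and the assumption that experts_module is an iterable nn.ModuleList are ours, not confirmed by the diff.

import torch


def ensure_stacked_mixtral(experts_module, dev, dt):
    # Rebuild the cache when it is absent or on the wrong device/dtype.
    stale = (
        not hasattr(experts_module, "_stacked_w1")
        or experts_module._stacked_w1.device != dev
        or experts_module._stacked_w1.dtype != dt  # assumed extra clause
    )
    if stale:
        # Stack per-expert weights once so the grouped FFN can batch over experts.
        experts_module._stacked_w1 = torch.stack(
            [e.w1.weight for e in experts_module]
        ).to(device=dev, dtype=dt)
        experts_module._stacked_w2 = torch.stack(
            [e.w2.weight for e in experts_module]
        ).to(device=dev, dtype=dt)
        experts_module._stacked_w3 = torch.stack(
            [e.w3.weight for e in experts_module]
        ).to(device=dev, dtype=dt)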
@@ -187,6 +185,13 @@ def moe_ffn_forward_grouped(
         W13 = experts_module._stacked_w13
         W2 = experts_module._stacked_w2
     else:
+        sample_mod = _resolve_expert(0)
+        if hasattr(sample_mod, "gate_up_proj"):
+            dt = sample_mod.gate_up_proj.weight.dtype  # type: ignore[assignment]
+        elif hasattr(sample_mod, "up_proj"):
+            dt = sample_mod.up_proj.weight.dtype  # type: ignore[assignment]
+        else:
+            dt = sample_mod.down_proj.weight.dtype  # type: ignore[assignment]
         if (
             not hasattr(experts_module, "_stacked_w13")
             or experts_module._stacked_w13.device != dev
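Together with the previous hunk, this moves the compute dtype dt off hidden_states and onto the expert weights themselves, so downcast or quantized experts keep their own precision regardless of the activations' dtype. Below is a standalone restatement of the probe added here; the fallback chain is copied from the hunk, while the helper name and signature are ours.

import torch


def expert_weight_dtype(sample_mod: torch.nn.Module) -> torch.dtype:
    # A fused gate+up projection takes priority...
    if hasattr(sample_mod, "gate_up_proj"):
        return sample_mod.gate_up_proj.weight.dtype
    # ...then a split up_proj layout...
    if hasattr(sample_mod, "up_proj"):
        return sample_mod.up_proj.weight.dtype
    # ...with down_proj as the last-resort probe.
    return sample_mod.down_proj.weight.dtype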