This commit is contained in:
Dan Saunders
2025-09-17 16:16:41 -04:00
parent 180920c7bf
commit 38b890a36b

View File

@@ -91,22 +91,35 @@ def moe_ffn_forward_grouped(
global LAST_ERROR
LAST_ERROR = None
bsz, seqlen, hdim = hidden_states.shape
compute_dtype = gate_linear.weight.dtype
if hidden_states.dtype != compute_dtype:
hidden_states = hidden_states.to(dtype=compute_dtype)
routing_dtype = gate_linear.weight.dtype
use_mixed_router = (
hidden_states.device.type == "cuda" and routing_dtype == torch.float32
)
if use_mixed_router:
x_router = hidden_states.to(dtype=routing_dtype)
router_logits = gate_linear(x_router)
else:
if hidden_states.dtype != routing_dtype:
hidden_states = hidden_states.to(dtype=routing_dtype)
x = hidden_states.view(-1, hdim)
router_logits = gate_linear(x)
if router_logits.dtype != routing_dtype:
router_logits = router_logits.to(dtype=routing_dtype)
x = hidden_states.view(-1, hdim)
router_logits = gate_linear(x)
# top-k routing executed in torch to avoid extra dependencies
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1)
x_rep = x.repeat_interleave(top_k, dim=0)
E = _num_experts(experts_module)
dev, dt = x.device, x.dtype
dev = hidden_states.device
dt: torch.dtype = hidden_states.dtype
first = experts_module[0]
is_mixtral = _is_mixtral_layout(first)
@@ -134,7 +147,7 @@ def moe_ffn_forward_grouped(
LAST_ERROR = "unsupported expert layout"
return None, None
def _resolve_expert(idx: int):
def _resolve_expert(idx: int) -> torch.nn.Module:
expert = experts_module[idx]
if nested_attr is None:
return expert
@@ -214,6 +227,18 @@ def moe_ffn_forward_grouped(
W13 = experts_module._stacked_w13
W2 = experts_module._stacked_w2
dt = W13.dtype
if router_logits.dtype != dt:
router_logits = router_logits.to(dtype=dt)
if x.dtype != dt:
x = x.to(dtype=dt)
flat_idx = topk_idx.view(-1)
if topk_weight.dtype != dt:
topk_weight = topk_weight.to(dtype=dt)
x_rep = x.repeat_interleave(top_k, dim=0)
if x_rep.dtype != dt:
x_rep = x_rep.to(dtype=dt)
As: List[torch.Tensor] = []
Bs: List[torch.Tensor] = []
expert_slices: List[Tuple[int, torch.Tensor]] = []
@@ -272,7 +297,7 @@ def moe_ffn_forward_grouped(
As2: List[torch.Tensor] = []
Bs2: List[torch.Tensor] = []
y_buf = torch.empty_like(x_rep)
y_buf = torch.empty_like(x_rep, dtype=dt)
for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
I2 = Yi.shape[-1] // 2
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]