Dan Saunders
2025-09-18 11:20:08 -04:00
parent bbf1f14ca4
commit 01b6792c2e

@@ -26,54 +26,30 @@ def available() -> bool:
         return False
 
 
+def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
+    impls: List[torch.nn.Module] = []
+    for exp in experts_module:
+        impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
+    return impls
+
+
 def _stack_weights(
     experts_module,
     names: Tuple[str, ...],
     *,
-    key: str,
     dtype: torch.dtype,
     device: torch.device,
 ) -> torch.Tensor:
-    attr = f"_ax_grouped_{key}"
-    cached = getattr(experts_module, attr, None)
-    if cached is not None and cached.dtype == dtype and cached.device == device:
-        return cached
     tensors: List[torch.Tensor] = []
-    for exp in experts_module:
-        mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
+    for mod in _iter_expert_impls(experts_module):
         parts = [getattr(mod, name).weight.t() for name in names]
         tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
-    stacked = (
+    return (
         torch.stack(tensors, dim=0)
         .to(device=device, dtype=dtype, non_blocking=True)
         .contiguous()
     )
-    setattr(experts_module, attr, stacked)
-    return stacked
-
-
-def _call_grouped_mm(
-    As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
-) -> List[torch.Tensor]:
-    if not As:
-        return []
-    if dtype not in (torch.bfloat16, torch.float16):
-        raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm")
-    As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
-    Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
-    device = As2[0].device
-    lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
-    offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
-    Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
-    outs: List[torch.Tensor] = []
-    start = 0
-    for size in lengths.tolist():
-        outs.append(Y_cat[start : start + size])
-        start += size
-    return outs
 
 
 def moe_ffn_forward_grouped(
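Note on the removed `_call_grouped_mm` wrapper: both the old helper and the new inline calls feed the private `torch._grouped_mm` op a row-sorted 2D activation, a stacked 3D weight, and int32 cumulative end offsets (one per group), as the old helper's `torch.cumsum(lengths)` shows. A minimal reference loop, not part of the commit and with a made-up name `grouped_mm_ref`, that emulates those semantics for sanity-checking:

import torch

def grouped_mm_ref(a: torch.Tensor, b: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # a: [total_rows, K], rows grouped by expert; b: [num_groups, K, N];
    # offs: cumulative *end* offset of each group, e.g. counts [3, 0, 2] -> offs [3, 3, 5].
    out = a.new_empty(a.shape[0], b.shape[-1])
    start = 0
    for g in range(b.shape[0]):
        end = int(offs[g])
        out[start:end] = a[start:end] @ b[g]  # a zero-row group contributes an empty slice
        start = end
    return out

Experts that received no tokens (the `sel.any()` guard in the old list-building loop) simply fall through here as empty slices.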
@@ -99,30 +75,30 @@ def moe_ffn_forward_grouped(
         )
         return None, None
-    sample_mod = getattr(
-        experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
-    )
-    for suffix in ("w13", "w2"):
-        attr = f"_ax_grouped_{suffix}"
-        if hasattr(experts_module, attr):
-            delattr(experts_module, attr)
+    expert_impls = _iter_expert_impls(experts_module)
+    sample_mod = expert_impls[0]
     if (
         hasattr(sample_mod, "w1")
         and hasattr(sample_mod, "w3")
         and hasattr(sample_mod, "w2")
     ):
         w13 = _stack_weights(
-            experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
-        )
-        w2 = _stack_weights(
-            experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
+            experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
         )
+        w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
     else:
         if hasattr(sample_mod, "gate_up_proj"):
             names13: Tuple[str, ...] = ("gate_up_proj",)
         else:
             names13 = ("up_proj", "gate_proj")
-        w13 = _stack_weights(
-            experts_module, names13, key="w13", dtype=expert_dtype, device=device
-        )
+        w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
         w2 = _stack_weights(
-            experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
+            experts_module, ("down_proj",), dtype=expert_dtype, device=device
         )
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
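For orientation, `_stack_weights` transposes each expert's `nn.Linear` weight from `[out_features, in_features]` to `[in_features, out_features]`, concatenates the gate/up pair along the last dim, and stacks across experts. A standalone sketch with arbitrary sizes (4 experts, hidden 8, FFN 16; the `ModuleDict` scaffolding is hypothetical, standing in for whatever expert modules the model uses):

import torch
import torch.nn as nn

E, hdim, ffn = 4, 8, 16
experts = [
    nn.ModuleDict({
        "w1": nn.Linear(hdim, ffn, bias=False),  # gate projection
        "w3": nn.Linear(hdim, ffn, bias=False),  # up projection
    })
    for _ in range(E)
]

# Mirrors _stack_weights(experts_module, ("w1", "w3"), ...): each weight.t()
# is [hdim, ffn]; cat gives [hdim, 2*ffn]; stack gives [E, hdim, 2*ffn].
w13 = torch.stack(
    [torch.cat([m["w1"].weight.t(), m["w3"].weight.t()], dim=-1) for m in experts]
)
assert w13.shape == (E, hdim, 2 * ffn)

Dropping the `key=`/`setattr` cache path means these stacked tensors are rebuilt on every call; the old dtype/device check that guarded cache reuse goes away with it.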
@@ -133,41 +109,45 @@ def moe_ffn_forward_grouped(
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
     flat_idx = topk_idx.view(-1)
-    x_rep = x_flat.repeat_interleave(top_k, dim=0)
+    num_experts = len(expert_impls)
     if flat_idx.numel() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
-    as_list: List[torch.Tensor] = []
-    bs_list: List[torch.Tensor] = []
-    slices: List[Tuple[int, torch.Tensor]] = []
-    for i, _ in enumerate(experts_module):
-        sel = flat_idx == i
-        if sel.any():
-            as_list.append(x_rep[sel])
-            bs_list.append(w13[i])
-            slices.append((i, sel))
+    assignments = torch.bincount(flat_idx, minlength=num_experts)
+    if assignments.sum() == 0:
+        zero = torch.zeros_like(x_flat)
+        return zero.view(bsz, seqlen, hdim), router_logits
-    if not as_list:
-        return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
+    perm = torch.argsort(flat_idx, stable=True)
+    token_indices_sorted = perm // top_k
+    scores_sorted = topk_weight.view(-1)[perm]
-    up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
+    gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
+    routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
-    down_inputs: List[torch.Tensor] = []
-    down_weights: List[torch.Tensor] = []
-    buf = torch.empty_like(x_rep)
-    for (i, _sel), Yi in zip(slices, up_out, strict=False):
-        mid = Yi.shape[-1] // 2
-        hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:]
-        down_inputs.append(hidden)
-        down_weights.append(w2[i])
+    offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
+    if offsets[-1].item() == 0:
+        zero = torch.zeros_like(x_flat)
+        return zero.view(bsz, seqlen, hdim), router_logits
-    down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
+    mid = w13.shape[-1] // 2
+    w_gate = w13[..., :mid]
+    w_up = w13[..., mid:]
-    for (_i, sel), tensor in zip(slices, down_out, strict=False):
-        buf[sel] = tensor
+    w_gate_t = w_gate.transpose(-2, -1).contiguous()
+    w_up_t = w_up.transpose(-2, -1).contiguous()
+    w2_t = w2.transpose(-2, -1).contiguous()
-    combined = (
-        (buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
-        .sum(dim=1)
-        .to(torch.bfloat16)
-    )
+    routed_in = routed_input.to(expert_dtype)
+    gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
+    up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
+    activated = F.silu(gate_out) * up_out
+    down_out = torch._grouped_mm(activated, w2_t, offs=offsets)
+    weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32)
+    weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype)
+    combined = torch.zeros_like(x_flat)
+    combined.scatter_add_(0, gather_index, weighted)
     return combined.view(bsz, seqlen, hdim), router_logits
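The index bookkeeping in this last hunk is dense; a self-contained sketch, not from the commit, of how the sort, gather, and scatter_add_ pieces fit together (arbitrary sizes, with identity "expert outputs" standing in for the FFN so the round trip is checkable):

import torch

tokens, top_k, num_experts, hdim = 5, 2, 4, 3
flat_idx = torch.randint(num_experts, (tokens * top_k,))
x_flat = torch.randn(tokens, hdim)
topk_weight = torch.full((tokens, top_k), 1.0 / top_k)  # pre-normalized router weights

# Sort token-expert assignments so each expert's rows become contiguous.
perm = torch.argsort(flat_idx, stable=True)
token_indices_sorted = perm // top_k  # sorted row i originated from this token
scores_sorted = topk_weight.view(-1)[perm]
counts = torch.bincount(flat_idx, minlength=num_experts)
offsets = torch.cumsum(counts.to(torch.int32), dim=0)  # per-expert end offsets for grouped mm

# gather duplicates each routed token into the sorted stream; scatter_add_
# accumulates the weighted per-copy results back, summing over the top_k copies.
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed = torch.gather(x_flat, 0, gather_index)  # stand-in for the expert FFN output
weighted = routed * scores_sorted.unsqueeze(-1)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, weighted)
# With identity experts and normalized weights, the combine recovers the input.
assert torch.allclose(combined, x_flat, atol=1e-6)

`perm // top_k` works because position p in the flattened [tokens, top_k] index array belongs to token p // top_k; the stable sort only reorders those positions, so the mapping survives the sort.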