This commit is contained in:
Dan Saunders
2025-09-18 11:20:08 -04:00
parent bbf1f14ca4
commit 01b6792c2e

View File

@@ -26,54 +26,30 @@ def available() -> bool:
return False return False
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
impls: List[torch.nn.Module] = []
for exp in experts_module:
impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
return impls
def _stack_weights( def _stack_weights(
experts_module, experts_module,
names: Tuple[str, ...], names: Tuple[str, ...],
*, *,
key: str,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> torch.Tensor: ) -> torch.Tensor:
attr = f"_ax_grouped_{key}"
cached = getattr(experts_module, attr, None)
if cached is not None and cached.dtype == dtype and cached.device == device:
return cached
tensors: List[torch.Tensor] = [] tensors: List[torch.Tensor] = []
for exp in experts_module: for mod in _iter_expert_impls(experts_module):
mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
parts = [getattr(mod, name).weight.t() for name in names] parts = [getattr(mod, name).weight.t() for name in names]
tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)) tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
stacked = ( return (
torch.stack(tensors, dim=0) torch.stack(tensors, dim=0)
.to(device=device, dtype=dtype, non_blocking=True) .to(device=device, dtype=dtype, non_blocking=True)
.contiguous() .contiguous()
) )
setattr(experts_module, attr, stacked)
return stacked
def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
) -> List[torch.Tensor]:
if not As:
return []
if dtype not in (torch.bfloat16, torch.float16):
raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm")
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
device = As2[0].device
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
outs: List[torch.Tensor] = []
start = 0
for size in lengths.tolist():
outs.append(Y_cat[start : start + size])
start += size
return outs
def moe_ffn_forward_grouped( def moe_ffn_forward_grouped(
@@ -99,30 +75,30 @@ def moe_ffn_forward_grouped(
) )
return None, None return None, None
sample_mod = getattr( for suffix in ("w13", "w2"):
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0]) attr = f"_ax_grouped_{suffix}"
) if hasattr(experts_module, attr):
delattr(experts_module, attr)
expert_impls = _iter_expert_impls(experts_module)
sample_mod = expert_impls[0]
if ( if (
hasattr(sample_mod, "w1") hasattr(sample_mod, "w1")
and hasattr(sample_mod, "w3") and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2") and hasattr(sample_mod, "w2")
): ):
w13 = _stack_weights( w13 = _stack_weights(
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
)
w2 = _stack_weights(
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
) )
w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
else: else:
if hasattr(sample_mod, "gate_up_proj"): if hasattr(sample_mod, "gate_up_proj"):
names13: Tuple[str, ...] = ("gate_up_proj",) names13: Tuple[str, ...] = ("gate_up_proj",)
else: else:
names13 = ("up_proj", "gate_proj") names13 = ("up_proj", "gate_proj")
w13 = _stack_weights( w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
experts_module, names13, key="w13", dtype=expert_dtype, device=device
)
w2 = _stack_weights( w2 = _stack_weights(
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device experts_module, ("down_proj",), dtype=expert_dtype, device=device
) )
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype) x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
@@ -133,41 +109,45 @@ def moe_ffn_forward_grouped(
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1) flat_idx = topk_idx.view(-1)
x_rep = x_flat.repeat_interleave(top_k, dim=0) num_experts = len(expert_impls)
if flat_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
as_list: List[torch.Tensor] = [] assignments = torch.bincount(flat_idx, minlength=num_experts)
bs_list: List[torch.Tensor] = [] if assignments.sum() == 0:
slices: List[Tuple[int, torch.Tensor]] = [] zero = torch.zeros_like(x_flat)
for i, _ in enumerate(experts_module): return zero.view(bsz, seqlen, hdim), router_logits
sel = flat_idx == i
if sel.any():
as_list.append(x_rep[sel])
bs_list.append(w13[i])
slices.append((i, sel))
if not as_list: perm = torch.argsort(flat_idx, stable=True)
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits token_indices_sorted = perm // top_k
scores_sorted = topk_weight.view(-1)[perm]
up_out = _call_grouped_mm(as_list, bs_list, expert_dtype) gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
down_inputs: List[torch.Tensor] = [] offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
down_weights: List[torch.Tensor] = [] if offsets[-1].item() == 0:
buf = torch.empty_like(x_rep) zero = torch.zeros_like(x_flat)
for (i, _sel), Yi in zip(slices, up_out, strict=False): return zero.view(bsz, seqlen, hdim), router_logits
mid = Yi.shape[-1] // 2
hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:]
down_inputs.append(hidden)
down_weights.append(w2[i])
down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype) mid = w13.shape[-1] // 2
w_gate = w13[..., :mid]
w_up = w13[..., mid:]
for (_i, sel), tensor in zip(slices, down_out, strict=False): w_gate_t = w_gate.transpose(-2, -1).contiguous()
buf[sel] = tensor w_up_t = w_up.transpose(-2, -1).contiguous()
w2_t = w2.transpose(-2, -1).contiguous()
combined = ( routed_in = routed_input.to(expert_dtype)
(buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)) gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
.sum(dim=1) up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
.to(torch.bfloat16) activated = F.silu(gate_out) * up_out
) down_out = torch._grouped_mm(activated, w2_t, offs=offsets)
weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32)
weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, weighted)
return combined.view(bsz, seqlen, hdim), router_logits return combined.view(bsz, seqlen, hdim), router_logits