fix perf degradation

This commit is contained in:
Dan Saunders
2025-09-17 18:20:37 -04:00
parent fd87eed501
commit 03d4c2683e
3 changed files with 61 additions and 32 deletions

View File

@@ -29,7 +29,9 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
try: try:
selected = MOEBackend(choice) selected = MOEBackend(choice)
except ValueError: except ValueError:
warnings.warn(f"Unknown moe backend '{choice}', falling back to auto") warnings.warn(
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
)
selected = MOEBackend.AUTO selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO: if selected == MOEBackend.AUTO:
@@ -38,7 +40,8 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
return MOEBackend.NAIVE return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped(): if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn( warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to naive" "torch_grouped requested but torch>=2.8 not detected; falling back to naive",
stacklevel=2,
) )
return MOEBackend.NAIVE return MOEBackend.NAIVE
return selected return selected

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Sequence, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -24,14 +24,31 @@ def available() -> bool:
def _stack_weights( def _stack_weights(
experts: Sequence[torch.nn.Module], names: Tuple[str, ...] experts_module,
names: Tuple[str, ...],
*,
key: str,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor: ) -> torch.Tensor:
stacked: List[torch.Tensor] = [] attr = f"_ax_grouped_{key}"
for expert in experts: cached = getattr(experts_module, attr, None)
mod = getattr(expert, "mlp", getattr(expert, "ffn", expert)) if cached is not None and cached.dtype == dtype and cached.device == device:
return cached
tensors: List[torch.Tensor] = []
for exp in experts_module:
mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
parts = [getattr(mod, name).weight.t() for name in names] parts = [getattr(mod, name).weight.t() for name in names]
stacked.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)) tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
return torch.stack(stacked, dim=0)
stacked = (
torch.stack(tensors, dim=0)
.to(device=device, dtype=dtype, non_blocking=True)
.contiguous()
)
setattr(experts_module, attr, stacked)
return stacked
def _call_grouped_mm( def _call_grouped_mm(
@@ -40,19 +57,22 @@ def _call_grouped_mm(
if not As: if not As:
return [] return []
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] try:
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
device = As2[0].device Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) device = As2[0].device
Y_cat = torch.ops.aten._grouped_mm( offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs Y_cat = torch.ops.aten._grouped_mm(
) torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
outs: List[torch.Tensor] = [] )
start = 0 outs: List[torch.Tensor] = []
for m in offs.tolist(): start = 0
outs.append(Y_cat[start : start + m]) for m in offs.tolist():
start += m outs.append(Y_cat[start : start + m])
return outs start += m
return outs
except RuntimeError:
return None
def moe_ffn_forward_grouped( def moe_ffn_forward_grouped(
@@ -77,22 +97,27 @@ def moe_ffn_forward_grouped(
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
experts = list(experts_module) sample = getattr(
sample = getattr(experts[0], "mlp", getattr(experts[0], "ffn", experts[0])) experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
)
if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"): if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
w13 = _stack_weights(experts, ("w1", "w3")).to( w13 = _stack_weights(
device=device, dtype=expert_dtype experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
)
w2 = _stack_weights(
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
) )
w2 = _stack_weights(experts, ("w2",)).to(device=device, dtype=expert_dtype)
else: else:
names13 = ( names13 = (
("gate_up_proj",) ("gate_up_proj",)
if hasattr(sample, "gate_up_proj") if hasattr(sample, "gate_up_proj")
else ("up_proj", "gate_proj") else ("up_proj", "gate_proj")
) )
w13 = _stack_weights(experts, names13).to(device=device, dtype=expert_dtype) w13 = _stack_weights(
w2 = _stack_weights(experts, ("down_proj",)).to( experts_module, names13, key="w13", dtype=expert_dtype, device=device
device=device, dtype=expert_dtype )
w2 = _stack_weights(
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
) )
flat_idx = topk_idx.view(-1) flat_idx = topk_idx.view(-1)
@@ -101,7 +126,7 @@ def moe_ffn_forward_grouped(
as_list: List[torch.Tensor] = [] as_list: List[torch.Tensor] = []
bs_list: List[torch.Tensor] = [] bs_list: List[torch.Tensor] = []
slices: List[Tuple[int, torch.Tensor]] = [] slices: List[Tuple[int, torch.Tensor]] = []
for i in range(len(experts)): for i, _ in enumerate(experts_module):
sel = flat_idx == i sel = flat_idx == i
if sel.any(): if sel.any():
as_list.append(x_rep[sel]) as_list.append(x_rep[sel])

View File

@@ -33,7 +33,8 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
and not _moe_backends._probe_torch_grouped() and not _moe_backends._probe_torch_grouped()
): ):
warnings.warn( warnings.warn(
"torch_grouped selected but not available; falling back to naive" "torch_grouped selected but not available; falling back to naive",
stacklevel=2,
) )
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)