fix perf degradation
This commit is contained in:
@@ -29,7 +29,9 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
|||||||
try:
|
try:
|
||||||
selected = MOEBackend(choice)
|
selected = MOEBackend(choice)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
warnings.warn(f"Unknown moe backend '{choice}', falling back to auto")
|
warnings.warn(
|
||||||
|
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
|
||||||
|
)
|
||||||
selected = MOEBackend.AUTO
|
selected = MOEBackend.AUTO
|
||||||
|
|
||||||
if selected == MOEBackend.AUTO:
|
if selected == MOEBackend.AUTO:
|
||||||
@@ -38,7 +40,8 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
|||||||
return MOEBackend.NAIVE
|
return MOEBackend.NAIVE
|
||||||
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"torch_grouped requested but torch>=2.8 not detected; falling back to naive"
|
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return MOEBackend.NAIVE
|
return MOEBackend.NAIVE
|
||||||
return selected
|
return selected
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Optional, Sequence, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -24,14 +24,31 @@ def available() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _stack_weights(
|
def _stack_weights(
|
||||||
experts: Sequence[torch.nn.Module], names: Tuple[str, ...]
|
experts_module,
|
||||||
|
names: Tuple[str, ...],
|
||||||
|
*,
|
||||||
|
key: str,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
stacked: List[torch.Tensor] = []
|
attr = f"_ax_grouped_{key}"
|
||||||
for expert in experts:
|
cached = getattr(experts_module, attr, None)
|
||||||
mod = getattr(expert, "mlp", getattr(expert, "ffn", expert))
|
if cached is not None and cached.dtype == dtype and cached.device == device:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
tensors: List[torch.Tensor] = []
|
||||||
|
for exp in experts_module:
|
||||||
|
mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||||
parts = [getattr(mod, name).weight.t() for name in names]
|
parts = [getattr(mod, name).weight.t() for name in names]
|
||||||
stacked.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
|
tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
|
||||||
return torch.stack(stacked, dim=0)
|
|
||||||
|
stacked = (
|
||||||
|
torch.stack(tensors, dim=0)
|
||||||
|
.to(device=device, dtype=dtype, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
setattr(experts_module, attr, stacked)
|
||||||
|
return stacked
|
||||||
|
|
||||||
|
|
||||||
def _call_grouped_mm(
|
def _call_grouped_mm(
|
||||||
@@ -40,19 +57,22 @@ def _call_grouped_mm(
|
|||||||
if not As:
|
if not As:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
try:
|
||||||
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
device = As2[0].device
|
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
device = As2[0].device
|
||||||
Y_cat = torch.ops.aten._grouped_mm(
|
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||||
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
|
Y_cat = torch.ops.aten._grouped_mm(
|
||||||
)
|
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
|
||||||
outs: List[torch.Tensor] = []
|
)
|
||||||
start = 0
|
outs: List[torch.Tensor] = []
|
||||||
for m in offs.tolist():
|
start = 0
|
||||||
outs.append(Y_cat[start : start + m])
|
for m in offs.tolist():
|
||||||
start += m
|
outs.append(Y_cat[start : start + m])
|
||||||
return outs
|
start += m
|
||||||
|
return outs
|
||||||
|
except RuntimeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def moe_ffn_forward_grouped(
|
def moe_ffn_forward_grouped(
|
||||||
@@ -77,22 +97,27 @@ def moe_ffn_forward_grouped(
|
|||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
experts = list(experts_module)
|
sample = getattr(
|
||||||
sample = getattr(experts[0], "mlp", getattr(experts[0], "ffn", experts[0]))
|
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
||||||
|
)
|
||||||
if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
|
if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
|
||||||
w13 = _stack_weights(experts, ("w1", "w3")).to(
|
w13 = _stack_weights(
|
||||||
device=device, dtype=expert_dtype
|
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
||||||
|
)
|
||||||
|
w2 = _stack_weights(
|
||||||
|
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
w2 = _stack_weights(experts, ("w2",)).to(device=device, dtype=expert_dtype)
|
|
||||||
else:
|
else:
|
||||||
names13 = (
|
names13 = (
|
||||||
("gate_up_proj",)
|
("gate_up_proj",)
|
||||||
if hasattr(sample, "gate_up_proj")
|
if hasattr(sample, "gate_up_proj")
|
||||||
else ("up_proj", "gate_proj")
|
else ("up_proj", "gate_proj")
|
||||||
)
|
)
|
||||||
w13 = _stack_weights(experts, names13).to(device=device, dtype=expert_dtype)
|
w13 = _stack_weights(
|
||||||
w2 = _stack_weights(experts, ("down_proj",)).to(
|
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
||||||
device=device, dtype=expert_dtype
|
)
|
||||||
|
w2 = _stack_weights(
|
||||||
|
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
@@ -101,7 +126,7 @@ def moe_ffn_forward_grouped(
|
|||||||
as_list: List[torch.Tensor] = []
|
as_list: List[torch.Tensor] = []
|
||||||
bs_list: List[torch.Tensor] = []
|
bs_list: List[torch.Tensor] = []
|
||||||
slices: List[Tuple[int, torch.Tensor]] = []
|
slices: List[Tuple[int, torch.Tensor]] = []
|
||||||
for i in range(len(experts)):
|
for i, _ in enumerate(experts_module):
|
||||||
sel = flat_idx == i
|
sel = flat_idx == i
|
||||||
if sel.any():
|
if sel.any():
|
||||||
as_list.append(x_rep[sel])
|
as_list.append(x_rep[sel])
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
|||||||
and not _moe_backends._probe_torch_grouped()
|
and not _moe_backends._probe_torch_grouped()
|
||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"torch_grouped selected but not available; falling back to naive"
|
"torch_grouped selected but not available; falling back to naive",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
|||||||
Reference in New Issue
Block a user