commit 63544ce709 (parent 3bfed0aac8)
Author: Dan Saunders
Date: 2025-09-19 11:34:27 -04:00
3 changed files with 144 additions and 81 deletions

View File

@@ -1,11 +1,18 @@
 #!/usr/bin/env python
 import argparse
+import sys
 import time
+from pathlib import Path

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

+try:
+    from axolotl.kernels.moe import torch_grouped as tg
+except Exception:  # pragma: no cover - fallback when torch_grouped unavailable
+    tg = None
+

 class SwiGLUMlp(nn.Module):
     def __init__(self, hidden_size: int, intermediate_size: int):
@@ -51,10 +58,10 @@ def forward_naive(
     return y.view(bsz, seqlen, hdim)


-def bench(fn, *args, iters=50, warmup=10, sync=True):
+def bench(fn, iters=50, warmup=10, sync=True):
     # warmup
     for _ in range(warmup):
-        fn(*args)
+        fn()
     if sync and torch.cuda.is_available():
         torch.cuda.synchronize()
     # measure
@@ -63,7 +70,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
         if sync and torch.cuda.is_available():
             torch.cuda.synchronize()
         t0 = time.perf_counter()
-        fn(*args)
+        fn()
         if sync and torch.cuda.is_available():
             torch.cuda.synchronize()
         dt = (time.perf_counter() - t0) * 1000.0
@@ -94,6 +101,18 @@ def main():
     )
     p.add_argument("--iters", type=int, default=50)
     p.add_argument("--warmup", type=int, default=10)
+    p.add_argument(
+        "--hf-block",
+        type=str,
+        default="none",
+        choices=["none", "qwen2_moe"],
+        help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
+    )
+    p.add_argument(
+        "--profile",
+        action="store_true",
+        help="Capture CUDA profiler tables for naive and grouped runs.",
+    )
     args = p.parse_args()

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -107,101 +126,136 @@ def main():
     if device == "cuda":
         torch.cuda.manual_seed(0)

-    # Model
-    experts = Experts(args.experts, args.hidden, args.inter).to(
-        device=device, dtype=dtype
-    )
-    gate = nn.Linear(args.hidden, args.experts, bias=False).to(
-        device=device, dtype=dtype
-    )
     # data
     x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
-    # Report config
     tokens = args.bsz * args.seq
-    print(
-        f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}"
-    )
-    # Naive baseline
-    t_naive = bench(
-        forward_naive,
-        x,
-        gate,
-        experts,
-        args.top_k,
-        iters=args.iters,
-        warmup=args.warmup,
-    )
-    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+    use_hf = args.hf_block != "none"
+    if use_hf:
+        project_root = Path(__file__).resolve().parents[2]
+        transformers_src = project_root / "transformers" / "src"
+        if transformers_src.exists() and str(transformers_src) not in sys.path:
+            sys.path.append(str(transformers_src))
+        if args.hf_block == "qwen2_moe":
+            from transformers.models.qwen2_moe.configuration_qwen2_moe import (
+                Qwen2MoeConfig,
+            )
+            from transformers.models.qwen2_moe.modeling_qwen2_moe import (
+                Qwen2MoeSparseMoeBlock,
+            )
+
+            cfg = Qwen2MoeConfig(
+                hidden_size=args.hidden,
+                moe_intermediate_size=args.inter,
+                shared_expert_intermediate_size=args.inter,
+                num_experts=args.experts,
+                num_experts_per_tok=args.top_k,
+                norm_topk_prob=True,
+                qkv_bias=True,
+            )
+            block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped.load_state_dict(block_naive.state_dict())
+
+        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+            out, _ = block_naive(inp)
+            return out
+
+        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+            if tg is None or not tg.available():
+                return torch.empty(0)
+            block_grouped.experts._ax_parent_block = block_grouped
+            y, _ = tg.moe_ffn_forward_grouped(
+                inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
+            )
+            return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
+        )
+    else:
+        experts = Experts(args.experts, args.hidden, args.inter).to(
+            device=device, dtype=dtype
+        )
+        gate = nn.Linear(args.hidden, args.experts, bias=False).to(
+            device=device, dtype=dtype
+        )
+
+        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+            return forward_naive(inp, gate, experts, args.top_k)
+
+        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+            if tg is None or not tg.available():
+                return torch.empty(0)
+            y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
+            return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k}"
+        )
+
+    # Benchmark naive
+    t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
     tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
     print(
         f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
     )

     # Prepare reference output once for checks
     with torch.no_grad():
-        y_ref = forward_naive(x, gate, experts, args.top_k)
+        y_ref = run_naive_model(x)

-    # torch_grouped backend (PyTorch 2.8+)
-    try:
-        from axolotl.kernels.moe import torch_grouped as tg
-    except Exception:
-        tg = None
+    # Benchmark grouped
     if tg is not None and tg.available():
-
-        def forward_tg(a, g, ex, topk):
-            y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
-            return y
-
-        y_tg = forward_tg(x, gate, experts, args.top_k)
-        if y_tg is not None:
-            t_ms = bench(
-                forward_tg,
-                x,
-                gate,
-                experts,
-                args.top_k,
+        y_grouped = run_grouped_model(x)
+        if y_grouped.numel() == 0:
+            print("torch_grouped\tN/A (op not callable)")
+        else:
+            t_grouped = bench(
+                lambda: run_grouped_model(x),
                 iters=args.iters,
                 warmup=args.warmup,
             )
-            tflops = flops_total / ((t_ms / 1000.0) * 1e12)
-            speedup = t_naive / t_ms
+            tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
+            speedup = t_naive / t_grouped
             print(
-                f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
+                f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
+                f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
             )
+            diff = (y_ref.float() - y_grouped.float()).abs()
+            max_abs = diff.max().item()
+            mean_abs = diff.mean().item()
+            rel_l2 = (
+                (diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
+            )
+            print(
+                f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
+            )
-            with torch.no_grad():
-                y_fast = y_tg
-                y_ref32 = y_ref.float()
-                y_fast32 = y_fast.float()
-                diff = (y_ref32 - y_fast32).abs()
-                max_abs = diff.max().item()
-                mean_abs = diff.mean().item()
-                rel_l2 = (
-                    (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-                )
-                print(
-                    f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-                )
-        else:
-            print("torch_grouped\tN/A (op not callable)")
     else:
         print("torch_grouped\tN/A (unavailable)")

-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
-    ) as prof:
-        forward_naive(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+    if args.profile and tg is not None and tg.available():
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
+        ) as prof:
+            run_naive_model(x)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA],
-        record_shapes=True,
-        with_stack=False,
-    ) as prof:
-        forward_tg(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True,
+            with_stack=False,
+        ) as prof:
+            run_grouped_model(x)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


 if __name__ == "__main__":
     main()
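
The tok/s and TFLOP/s columns divide estimate_moe_flops(...) by measured wall time. As a rough sketch of what such an estimate might cover, assuming it counts only the three expert projections of a top-k SwiGLU MoE (the repo's actual helper may also include router or shared-expert terms):

def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    # Each routed token visits top_k experts; each expert applies three matmuls
    # (gate, up, down) of roughly 2 * hidden * inter FLOPs per token.
    # Assumption: router and shared-expert FLOPs are omitted as negligible.
    return float(tokens * top_k * 3 * 2 * hidden * inter)

# Example: tokens=2048, hidden=1024, inter=2816, top_k=4 gives about 1.42e11 FLOPs,
# so a 10 ms forward pass corresponds to roughly 14.2 TFLOP/s.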

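The bench rewrite swaps the (fn, *args) calling convention for zero-argument closures, so one harness can time the toy SwiGLU path, the HF block, and the grouped backend without caring about their signatures. A minimal self-contained version of the pattern; reducing the per-iteration dt values to a median is an assumption here, the script's own reduction may differ:

import time
import torch

def bench(fn, iters=50, warmup=10, sync=True):
    # warmup lets allocator caches and any lazy initialization settle
    for _ in range(warmup):
        fn()
    if sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        if sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn()
        if sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append((time.perf_counter() - t0) * 1000.0)  # ms per iteration
    return sorted(times)[len(times) // 2]  # assumed reduction: median

x = torch.randn(64, 64)
print(f"{bench(lambda: x @ x, iters=10, warmup=3):.3f} ms")

With the new flags, adding --hf-block qwen2_moe --profile to an invocation exercises the Hugging Face path and prints the CUDA profiler tables for both runs.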
View File

@@ -267,9 +267,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None

-    expert_impls = _iter_expert_impls(experts_module)
+    parent_block = getattr(experts_module, "_ax_parent_block", None)
+    expert_container = getattr(experts_module, "experts", experts_module)
+    expert_impls = _iter_expert_impls(expert_container)
     sample_mod = expert_impls[0]
-    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
     w_gate = storage.gate
     w_up = storage.up
     w2 = storage.down
@@ -278,11 +281,12 @@ def moe_ffn_forward_grouped(
     router_logits = gate_linear(x_flat.to(routing_dtype))

     shared_out_flat: Optional[torch.Tensor] = None
-    if hasattr(experts_module, "shared_expert"):
-        shared_expert = experts_module.shared_expert
+    shared_owner = parent_block if parent_block is not None else experts_module
+    if hasattr(shared_owner, "shared_expert"):
+        shared_expert = shared_owner.shared_expert
         shared_out_flat = shared_expert(x_flat)
         shared_out_flat = shared_out_flat.to(expert_dtype)
-        shared_gate = getattr(experts_module, "shared_expert_gate", None)
+        shared_gate = getattr(shared_owner, "shared_expert_gate", None)
         if shared_gate is not None:
             gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
             gate_vals = torch.sigmoid(gate_input)

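For context on the shared_expert and shared_expert_gate attributes being resolved above: in the Qwen2-MoE convention, the block output is the routed experts' contribution plus the shared expert's output scaled per token by a sigmoid gate. A self-contained sketch of that combination; the nn.Sequential shared expert here is a simplification of the real SwiGLU MLP, and routed_out stands in for the grouped experts' output:

import torch
import torch.nn as nn

hidden, n_tokens = 64, 10
shared_expert = nn.Sequential(
    nn.Linear(hidden, 4 * hidden), nn.SiLU(), nn.Linear(4 * hidden, hidden)
)
shared_expert_gate = nn.Linear(hidden, 1, bias=False)

x_flat = torch.randn(n_tokens, hidden)
routed_out = torch.randn(n_tokens, hidden)  # stand-in for the routed experts' sum

gate_vals = torch.sigmoid(shared_expert_gate(x_flat))  # (n_tokens, 1)
y = routed_out + gate_vals * shared_expert(x_flat)     # broadcasts over hidden dim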
View File

@@ -76,6 +76,11 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
     @wraps(orig_forward)
     def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
         bsz, seqlen, hdim = hidden_states.shape
+        # expose parent block so grouped backend can access shared expert context
+        try:
+            self.experts._ax_parent_block = self
+        except Exception:
+            pass
         y, router_logits = _tg.moe_ffn_forward_grouped(
             hidden_states, self.gate, self.experts, self.top_k
         )
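
Taken together, the three diffs thread context through one attribute: the patched forward stamps _ax_parent_block onto the experts container, and the grouped backend walks back through it to find shared_expert and shared_expert_gate, which live on the MoE block rather than on .experts. A toy illustration of the handshake; Block is a made-up stand-in for an HF sparse MoE block:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
        self.shared_expert = nn.Linear(8, 8)  # lives on the block, not on .experts

def resolve_shared_expert(experts_module):
    # mirrors the lookup torch_grouped performs after this commit
    parent_block = getattr(experts_module, "_ax_parent_block", None)
    shared_owner = parent_block if parent_block is not None else experts_module
    return getattr(shared_owner, "shared_expert", None)

block = Block()
assert resolve_shared_expert(block.experts) is None  # without the handshake: not found
block.experts._ax_parent_block = block               # what the wrapped forward does
assert resolve_shared_expert(block.experts) is block.shared_expert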