Compare commits

...

10 Commits

Author SHA1 Message Date
Dan Saunders
8564961423 fix compile 2025-09-19 13:59:57 -04:00
Dan Saunders
ce21da9177 fix compile 2025-09-19 13:55:54 -04:00
Dan Saunders
b5dc58373f fix compile 2025-09-19 13:52:42 -04:00
Dan Saunders
7327144344 compile 2025-09-19 13:41:12 -04:00
Dan Saunders
fb11f696e9 bench sweep 2025-09-19 13:24:40 -04:00
Dan Saunders
105c817b0b default fix 2025-09-19 16:59:20 +00:00
Dan Saunders
64345e7707 recurse fix 2025-09-19 12:58:58 -04:00
Dan Saunders
0f8b921399 contig 2025-09-19 12:47:53 -04:00
Dan Saunders
336616d659 defaults 2025-09-19 16:45:39 +00:00
Dan Saunders
d2f1e23bcd fix 2025-09-19 12:45:18 -04:00
4 changed files with 375 additions and 14 deletions

View File

@@ -6,9 +6,11 @@ from __future__ import annotations
import argparse import argparse
import sys import sys
import time import time
import weakref
from pathlib import Path from pathlib import Path
import torch import torch
import torch._dynamo as dynamo
try: try:
from axolotl.kernels.moe import torch_grouped as tg from axolotl.kernels.moe import torch_grouped as tg
@@ -83,6 +85,11 @@ def main() -> None:
p.add_argument("--iters", type=int, default=50) p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10) p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true") p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args() args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -114,17 +121,41 @@ def main() -> None:
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype) x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
def run_naive(): # Optional torch.compile
y, _ = block_naive(x) run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y return y
def run_grouped(): def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available(): if tg is None or not tg.available():
return torch.empty(0) return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped( y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
)
return y if y is not None else torch.empty(0) return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup) t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)

311
scripts/bench_moe_sweep.py Normal file
View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
import weakref
main()

View File

@@ -27,7 +27,16 @@ def available() -> bool:
return False return False
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]: def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = [] impls: List[torch.nn.Module] = []
for exp in experts_module: for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp)) candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
@@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
continue continue
nested = getattr(candidate, "experts", None) nested = getattr(candidate, "experts", None)
if nested is not None: if nested is not None:
impls.extend(_iter_expert_impls(nested)) impls.extend(_iter_expert_impls(nested, visited))
continue continue
raise RuntimeError( raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module" "torch_grouped: unable to resolve expert implementation for module"
@@ -277,7 +286,14 @@ def moe_ffn_forward_grouped(
) )
return None, None return None, None
parent_block = getattr(experts_module, "_ax_parent_block", None) parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module) expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container) expert_impls = _iter_expert_impls(expert_container)
@@ -326,19 +342,21 @@ def moe_ffn_forward_grouped(
counts_i32 = assignments.to(device=device, dtype=torch.int32) counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32) offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype) routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype) w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype) w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype) w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_in.contiguous()
w_gate_t = w_gate_t.contiguous()
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets) gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out) torch.ops.aten.silu_(gate_out)
w_up_t = w_up_t.contiguous()
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets) up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out) gate_out.mul_(up_out)
gate_out = gate_out.contiguous()
w2_t = w2_t.contiguous()
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype) down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype) weights = scores_sorted.unsqueeze(-1).to(expert_dtype)

View File

@@ -1,4 +1,5 @@
import logging import logging
import weakref
from functools import wraps from functools import wraps
import torch import torch
@@ -78,7 +79,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
bsz, seqlen, hdim = hidden_states.shape bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context # expose parent block so grouped backend can access shared expert context
try: try:
self.experts._ax_parent_block = self self.experts._ax_parent_block_ref = weakref.ref(self)
except Exception: except Exception:
pass pass
y, router_logits = _tg.moe_ffn_forward_grouped( y, router_logits = _tg.moe_ffn_forward_grouped(