This commit is contained in:
Dan Saunders
2025-09-19 11:34:27 -04:00
parent 3bfed0aac8
commit 63544ce709
3 changed files with 144 additions and 81 deletions

View File

@@ -1,11 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
import argparse import argparse
import sys
import time import time
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover - fallback when torch_grouped unavailable
tg = None
class SwiGLUMlp(nn.Module): class SwiGLUMlp(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int): def __init__(self, hidden_size: int, intermediate_size: int):
@@ -51,10 +58,10 @@ def forward_naive(
return y.view(bsz, seqlen, hdim) return y.view(bsz, seqlen, hdim)
def bench(fn, *args, iters=50, warmup=10, sync=True): def bench(fn, iters=50, warmup=10, sync=True):
# warmup # warmup
for _ in range(warmup): for _ in range(warmup):
fn(*args) fn()
if sync and torch.cuda.is_available(): if sync and torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
# measure # measure
@@ -63,7 +70,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
if sync and torch.cuda.is_available(): if sync and torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.perf_counter() t0 = time.perf_counter()
fn(*args) fn()
if sync and torch.cuda.is_available(): if sync and torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
dt = (time.perf_counter() - t0) * 1000.0 dt = (time.perf_counter() - t0) * 1000.0
@@ -94,6 +101,18 @@ def main():
) )
p.add_argument("--iters", type=int, default=50) p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10) p.add_argument("--warmup", type=int, default=10)
p.add_argument(
"--hf-block",
type=str,
default="none",
choices=["none", "qwen2_moe"],
help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
)
p.add_argument(
"--profile",
action="store_true",
help="Capture CUDA profiler tables for naive and grouped runs.",
)
args = p.parse_args() args = p.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -107,101 +126,136 @@ def main():
if device == "cuda": if device == "cuda":
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
# Model
experts = Experts(args.experts, args.hidden, args.inter).to(
device=device, dtype=dtype
)
gate = nn.Linear(args.hidden, args.experts, bias=False).to(
device=device, dtype=dtype
)
# data # data
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype) x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Report config
tokens = args.bsz * args.seq tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
# Naive baseline use_hf = args.hf_block != "none"
t_naive = bench(
forward_naive, if use_hf:
x, project_root = Path(__file__).resolve().parents[2]
gate, transformers_src = project_root / "transformers" / "src"
experts, if transformers_src.exists() and str(transformers_src) not in sys.path:
args.top_k, sys.path.append(str(transformers_src))
iters=args.iters,
warmup=args.warmup, if args.hf_block == "qwen2_moe":
) from transformers.models.qwen2_moe.configuration_qwen2_moe import (
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k) Qwen2MoeConfig,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
Qwen2MoeSparseMoeBlock,
)
cfg = Qwen2MoeConfig(
hidden_size=args.hidden,
moe_intermediate_size=args.inter,
shared_expert_intermediate_size=args.inter,
num_experts=args.experts,
num_experts_per_tok=args.top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block_naive.state_dict())
def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
out, _ = block_naive(inp)
return out
def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
if tg is None or not tg.available():
return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped
y, _ = tg.moe_ffn_forward_grouped(
inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
)
return y if y is not None else torch.empty(0)
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
)
else:
experts = Experts(args.experts, args.hidden, args.inter).to(
device=device, dtype=dtype
)
gate = nn.Linear(args.hidden, args.experts, bias=False).to(
device=device, dtype=dtype
)
def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
return forward_naive(inp, gate, experts, args.top_k)
def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
if tg is None or not tg.available():
return torch.empty(0)
y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
return y if y is not None else torch.empty(0)
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
# Benchmark naive
t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12) tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print( print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s" f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
) )
# Prepare reference output once for checks
with torch.no_grad(): with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k) y_ref = run_naive_model(x)
# torch_grouped backend (PyTorch 2.8+) # Benchmark grouped
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception:
tg = None
if tg is not None and tg.available(): if tg is not None and tg.available():
y_grouped = run_grouped_model(x)
def forward_tg(a, g, ex, topk): if y_grouped.numel() == 0:
y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk) print("torch_grouped\tN/A (op not callable)")
return y else:
t_grouped = bench(
y_tg = forward_tg(x, gate, experts, args.top_k) lambda: run_grouped_model(x),
if y_tg is not None:
t_ms = bench(
forward_tg,
x,
gate,
experts,
args.top_k,
iters=args.iters, iters=args.iters,
warmup=args.warmup, warmup=args.warmup,
) )
tflops = flops_total / ((t_ms / 1000.0) * 1e12) tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_ms speedup = t_naive / t_grouped
print( print(
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
)
print(
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
) )
with torch.no_grad():
y_fast = y_tg
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
)
print(
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
print("torch_grouped\tN/A (op not callable)")
else: else:
print("torch_grouped\tN/A (unavailable)") print("torch_grouped\tN/A (unavailable)")
with torch.profiler.profile( if args.profile and tg is not None and tg.available():
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True with torch.profiler.profile(
) as prof: activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
forward_naive(x, gate, experts, args.top_k) ) as prof:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) run_naive_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile( with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True, record_shapes=True,
with_stack=False, with_stack=False,
) as prof: ) as prof:
forward_tg(x, gate, experts, args.top_k) run_grouped_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -267,9 +267,12 @@ def moe_ffn_forward_grouped(
) )
return None, None return None, None
expert_impls = _iter_expert_impls(experts_module) parent_block = getattr(experts_module, "_ax_parent_block", None)
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)
sample_mod = expert_impls[0] sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod) storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
w_gate = storage.gate w_gate = storage.gate
w_up = storage.up w_up = storage.up
w2 = storage.down w2 = storage.down
@@ -278,11 +281,12 @@ def moe_ffn_forward_grouped(
router_logits = gate_linear(x_flat.to(routing_dtype)) router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None shared_out_flat: Optional[torch.Tensor] = None
if hasattr(experts_module, "shared_expert"): shared_owner = parent_block if parent_block is not None else experts_module
shared_expert = experts_module.shared_expert if hasattr(shared_owner, "shared_expert"):
shared_expert = shared_owner.shared_expert
shared_out_flat = shared_expert(x_flat) shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype) shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(experts_module, "shared_expert_gate", None) shared_gate = getattr(shared_owner, "shared_expert_gate", None)
if shared_gate is not None: if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype)) gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input) gate_vals = torch.sigmoid(gate_input)

View File

@@ -76,6 +76,11 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
@wraps(orig_forward) @wraps(orig_forward)
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs): def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, hdim = hidden_states.shape bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context
try:
self.experts._ax_parent_block = self
except Exception:
pass
y, router_logits = _tg.moe_ffn_forward_grouped( y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k hidden_states, self.gate, self.experts, self.top_k
) )