From 63544ce709a5113a3a78ba4440b4e308d8d9311f Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 19 Sep 2025 11:34:27 -0400
Subject: [PATCH] bench_moe: benchmark HF MoE blocks and route shared-expert
 lookups through the parent block
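
Rework scripts/bench_moe.py so the naive and grouped paths share one
closure-based harness: bench() now takes a zero-argument callable, an
optional Hugging Face Qwen2 MoE block can be benchmarked via
--hf-block qwen2_moe, and the CUDA profiler tables are opt-in behind
--profile. On the kernel side, moe_ffn_forward_grouped() resolves
shared_expert / shared_expert_gate through an _ax_parent_block reference
that the monkeypatch (and the benchmark) attach to the experts module,
since those submodules live on the sparse MoE block rather than on the
experts container.

Example invocations (illustrative; only flags added or touched by this
patch are shown):

    python scripts/bench_moe.py --iters 50 --warmup 10
    python scripts/bench_moe.py --hf-block qwen2_moe --profile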
---
 scripts/bench_moe.py                     | 206 ++++++++++++++---
 src/axolotl/kernels/moe/torch_grouped.py |  14 +-
 src/axolotl/monkeypatch/moe_grouped.py   |   5 +
 3 files changed, 144 insertions(+), 81 deletions(-)

diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 76c3929c2..3cb434edb 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -1,11 +1,18 @@
 #!/usr/bin/env python
 import argparse
+import sys
 import time
+from pathlib import Path

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

+try:
+    from axolotl.kernels.moe import torch_grouped as tg
+except Exception:  # pragma: no cover - fallback when torch_grouped unavailable
+    tg = None
+

 class SwiGLUMlp(nn.Module):
     def __init__(self, hidden_size: int, intermediate_size: int):
@@ -51,10 +58,10 @@ def forward_naive(
     return y.view(bsz, seqlen, hdim)


-def bench(fn, *args, iters=50, warmup=10, sync=True):
+def bench(fn, iters=50, warmup=10, sync=True):
     # warmup
     for _ in range(warmup):
-        fn(*args)
+        fn()
     if sync and torch.cuda.is_available():
         torch.cuda.synchronize()
     # measure
@@ -63,7 +70,7 @@
     if sync and torch.cuda.is_available():
         torch.cuda.synchronize()
     t0 = time.perf_counter()
-    fn(*args)
+    fn()
     if sync and torch.cuda.is_available():
         torch.cuda.synchronize()
     dt = (time.perf_counter() - t0) * 1000.0
@@ -94,6 +101,18 @@ def main():
     )
     p.add_argument("--iters", type=int, default=50)
     p.add_argument("--warmup", type=int, default=10)
+    p.add_argument(
+        "--hf-block",
+        type=str,
+        default="none",
+        choices=["none", "qwen2_moe"],
+        help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
+    )
+    p.add_argument(
+        "--profile",
+        action="store_true",
+        help="Capture CUDA profiler tables for naive and grouped runs.",
+    )
     args = p.parse_args()

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -107,101 +126,136 @@ def main():
     if device == "cuda":
         torch.cuda.manual_seed(0)

-    # Model
-    experts = Experts(args.experts, args.hidden, args.inter).to(
-        device=device, dtype=dtype
-    )
-    gate = nn.Linear(args.hidden, args.experts, bias=False).to(
-        device=device, dtype=dtype
-    )
-
     # data
     x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)

-    # Report config
     tokens = args.bsz * args.seq
-    print(
-        f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}"
-    )

-    # Naive baseline
-    t_naive = bench(
-        forward_naive,
-        x,
-        gate,
-        experts,
-        args.top_k,
-        iters=args.iters,
-        warmup=args.warmup,
-    )
-    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+    use_hf = args.hf_block != "none"
+
+    if use_hf:
+        project_root = Path(__file__).resolve().parents[2]
+        transformers_src = project_root / "transformers" / "src"
+        if transformers_src.exists() and str(transformers_src) not in sys.path:
+            sys.path.append(str(transformers_src))
+
+        if args.hf_block == "qwen2_moe":
+            from transformers.models.qwen2_moe.configuration_qwen2_moe import (
+                Qwen2MoeConfig,
+            )
+            from transformers.models.qwen2_moe.modeling_qwen2_moe import (
+                Qwen2MoeSparseMoeBlock,
+            )
+
+            cfg = Qwen2MoeConfig(
+                hidden_size=args.hidden,
+                moe_intermediate_size=args.inter,
+                shared_expert_intermediate_size=args.inter,
+                num_experts=args.experts,
+                num_experts_per_tok=args.top_k,
+                norm_topk_prob=True,
+                qkv_bias=True,
+            )
+
+            block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped.load_state_dict(block_naive.state_dict())
+
+        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+            out, _ = block_naive(inp)
+            return out
+
+        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+            if tg is None or not tg.available():
+                return torch.empty(0)
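+            # Hand the parent MoE block to the grouped backend so it can reach
+            # shared_expert / shared_expert_gate, which live on the sparse MoE
+            # block itself rather than on its experts container.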
+            block_grouped.experts._ax_parent_block = block_grouped
+            y, _ = tg.moe_ffn_forward_grouped(
+                inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
+            )
+            return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
+        )
+
+    else:
+        experts = Experts(args.experts, args.hidden, args.inter).to(
+            device=device, dtype=dtype
+        )
+        gate = nn.Linear(args.hidden, args.experts, bias=False).to(
+            device=device, dtype=dtype
+        )
+
+        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+            return forward_naive(inp, gate, experts, args.top_k)
+
+        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+            if tg is None or not tg.available():
+                return torch.empty(0)
+            y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
+            return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k}"
+        )
+
+    # Benchmark naive
+    t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
     tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
     print(
         f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
     )

-    # Prepare reference output once for checks
     with torch.no_grad():
-        y_ref = forward_naive(x, gate, experts, args.top_k)
+        y_ref = run_naive_model(x)

-    # torch_grouped backend (PyTorch 2.8+)
-    try:
-        from axolotl.kernels.moe import torch_grouped as tg
-    except Exception:
-        tg = None
+    # Benchmark grouped
     if tg is not None and tg.available():
-
-        def forward_tg(a, g, ex, topk):
-            y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
-            return y
-
-        y_tg = forward_tg(x, gate, experts, args.top_k)
-        if y_tg is not None:
-            t_ms = bench(
-                forward_tg,
-                x,
-                gate,
-                experts,
-                args.top_k,
+        y_grouped = run_grouped_model(x)
+        if y_grouped.numel() == 0:
+            print("torch_grouped\tN/A (op not callable)")
+        else:
+            t_grouped = bench(
+                lambda: run_grouped_model(x),
                 iters=args.iters,
                 warmup=args.warmup,
             )
-            tflops = flops_total / ((t_ms / 1000.0) * 1e12)
-            speedup = t_naive / t_ms
+            tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
+            speedup = t_naive / t_grouped
             print(
-                f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
+                f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
+                f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
+            )
+            diff = (y_ref.float() - y_grouped.float()).abs()
+            max_abs = diff.max().item()
+            mean_abs = diff.mean().item()
+            rel_l2 = (
+                (diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
+            )
+            print(
+                f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
             )
-            with torch.no_grad():
-                y_fast = y_tg
-                y_ref32 = y_ref.float()
-                y_fast32 = y_fast.float()
-                diff = (y_ref32 - y_fast32).abs()
-                max_abs = diff.max().item()
-                mean_abs = diff.mean().item()
-                rel_l2 = (
-                    (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-                )
-                print(
-                    f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-                )
-        else:
-            print("torch_grouped\tN/A (op not callable)")
     else:
         print("torch_grouped\tN/A (unavailable)")

-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
-    ) as prof:
-        forward_naive(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+    if args.profile:
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
+        ) as prof:
+            run_naive_model(x)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA],
-        record_shapes=True,
-        with_stack=False,
-    ) as prof:
-        forward_tg(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+        if tg is not None and tg.available():
+            with torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CUDA],
+                record_shapes=True,
+                with_stack=False,
+            ) as prof:
+                run_grouped_model(x)
+            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


 if __name__ == "__main__":
diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index b5ea0c532..54299908d 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -267,9 +267,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None

-    expert_impls = _iter_expert_impls(experts_module)
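+    # _ax_parent_block is attached by the moe_grouped monkeypatch (and by the
+    # benchmark) so this backend can reach shared_expert / shared_expert_gate
+    # on the parent sparse MoE block; fall back to the experts module itself.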
+    parent_block = getattr(experts_module, "_ax_parent_block", None)
+    expert_container = getattr(experts_module, "experts", experts_module)
+
+    expert_impls = _iter_expert_impls(expert_container)
     sample_mod = expert_impls[0]
-    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
     w_gate = storage.gate
     w_up = storage.up
     w2 = storage.down
@@ -278,11 +281,12 @@ def moe_ffn_forward_grouped(
     router_logits = gate_linear(x_flat.to(routing_dtype))

     shared_out_flat: Optional[torch.Tensor] = None
-    if hasattr(experts_module, "shared_expert"):
-        shared_expert = experts_module.shared_expert
+    shared_owner = parent_block if parent_block is not None else experts_module
+    if hasattr(shared_owner, "shared_expert"):
+        shared_expert = shared_owner.shared_expert
         shared_out_flat = shared_expert(x_flat)
         shared_out_flat = shared_out_flat.to(expert_dtype)
-    shared_gate = getattr(experts_module, "shared_expert_gate", None)
+    shared_gate = getattr(shared_owner, "shared_expert_gate", None)
     if shared_gate is not None:
         gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
         gate_vals = torch.sigmoid(gate_input)
diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py
index 0d2d67c97..55eeffa89 100644
--- a/src/axolotl/monkeypatch/moe_grouped.py
+++ b/src/axolotl/monkeypatch/moe_grouped.py
@@ -76,6 +76,11 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
     @wraps(orig_forward)
     def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
         bsz, seqlen, hdim = hidden_states.shape
+        # expose parent block so grouped backend can access shared expert context
+        try:
+            self.experts._ax_parent_block = self
+        except Exception:
+            pass
         y, router_logits = _tg.moe_ffn_forward_grouped(
             hidden_states, self.gate, self.experts, self.top_k
         )