fix
@@ -1,11 +1,18 @@
 #!/usr/bin/env python
 import argparse
+import sys
 import time
+from pathlib import Path
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+try:
+    from axolotl.kernels.moe import torch_grouped as tg
+except Exception:  # pragma: no cover - fallback when torch_grouped unavailable
+    tg = None
+
 
 class SwiGLUMlp(nn.Module):
     def __init__(self, hidden_size: int, intermediate_size: int):
@@ -51,10 +58,10 @@ def forward_naive(
     return y.view(bsz, seqlen, hdim)
 
 
-def bench(fn, *args, iters=50, warmup=10, sync=True):
+def bench(fn, iters=50, warmup=10, sync=True):
     # warmup
     for _ in range(warmup):
-        fn(*args)
+        fn()
     if sync and torch.cuda.is_available():
         torch.cuda.synchronize()
     # measure
@@ -63,7 +70,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
         if sync and torch.cuda.is_available():
             torch.cuda.synchronize()
         t0 = time.perf_counter()
-        fn(*args)
+        fn()
         if sync and torch.cuda.is_available():
             torch.cuda.synchronize()
         dt = (time.perf_counter() - t0) * 1000.0
@@ -94,6 +101,18 @@ def main():
     )
     p.add_argument("--iters", type=int, default=50)
     p.add_argument("--warmup", type=int, default=10)
+    p.add_argument(
+        "--hf-block",
+        type=str,
+        default="none",
+        choices=["none", "qwen2_moe"],
+        help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
+    )
+    p.add_argument(
+        "--profile",
+        action="store_true",
+        help="Capture CUDA profiler tables for naive and grouped runs.",
+    )
     args = p.parse_args()
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -107,101 +126,136 @@ def main():
     if device == "cuda":
         torch.cuda.manual_seed(0)
 
-    # Model
-    experts = Experts(args.experts, args.hidden, args.inter).to(
-        device=device, dtype=dtype
-    )
-    gate = nn.Linear(args.hidden, args.experts, bias=False).to(
-        device=device, dtype=dtype
-    )
-
-    # data
     x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
-
-    # Report config
     tokens = args.bsz * args.seq
-    print(
-        f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}"
-    )
-
-    # Naive baseline
-    t_naive = bench(
-        forward_naive,
-        x,
-        gate,
-        experts,
-        args.top_k,
-        iters=args.iters,
-        warmup=args.warmup,
-    )
-    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+    use_hf = args.hf_block != "none"
+
+    if use_hf:
+        project_root = Path(__file__).resolve().parents[2]
+        transformers_src = project_root / "transformers" / "src"
+        if transformers_src.exists() and str(transformers_src) not in sys.path:
+            sys.path.append(str(transformers_src))
+
+        if args.hf_block == "qwen2_moe":
+            from transformers.models.qwen2_moe.configuration_qwen2_moe import (
+                Qwen2MoeConfig,
+            )
+            from transformers.models.qwen2_moe.modeling_qwen2_moe import (
+                Qwen2MoeSparseMoeBlock,
+            )
+
+            cfg = Qwen2MoeConfig(
+                hidden_size=args.hidden,
+                moe_intermediate_size=args.inter,
+                shared_expert_intermediate_size=args.inter,
+                num_experts=args.experts,
+                num_experts_per_tok=args.top_k,
+                norm_topk_prob=True,
+                qkv_bias=True,
+            )
+
+            block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+            block_grouped.load_state_dict(block_naive.state_dict())
+
+            def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+                out, _ = block_naive(inp)
+                return out
+
+            def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+                if tg is None or not tg.available():
+                    return torch.empty(0)
+                block_grouped.experts._ax_parent_block = block_grouped
+                y, _ = tg.moe_ffn_forward_grouped(
+                    inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
+                )
+                return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
+        )
+
+    else:
+        experts = Experts(args.experts, args.hidden, args.inter).to(
+            device=device, dtype=dtype
+        )
+        gate = nn.Linear(args.hidden, args.experts, bias=False).to(
+            device=device, dtype=dtype
+        )
+
+        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
+            return forward_naive(inp, gate, experts, args.top_k)
+
+        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
+            if tg is None or not tg.available():
+                return torch.empty(0)
+            y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
+            return y if y is not None else torch.empty(0)
+
+        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+        print(
+            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+            f"experts={args.experts} top_k={args.top_k}"
+        )
+
+    # Benchmark naive
+    t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
     tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
     print(
         f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
     )
 
     # Prepare reference output once for checks
     with torch.no_grad():
-        y_ref = forward_naive(x, gate, experts, args.top_k)
+        y_ref = run_naive_model(x)
 
-    # torch_grouped backend (PyTorch 2.8+)
-    try:
-        from axolotl.kernels.moe import torch_grouped as tg
-    except Exception:
-        tg = None
+    # Benchmark grouped
     if tg is not None and tg.available():
-
-        def forward_tg(a, g, ex, topk):
-            y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
-            return y
-
-        y_tg = forward_tg(x, gate, experts, args.top_k)
-        if y_tg is not None:
-            t_ms = bench(
-                forward_tg,
-                x,
-                gate,
-                experts,
-                args.top_k,
+        y_grouped = run_grouped_model(x)
+        if y_grouped.numel() == 0:
+            print("torch_grouped\tN/A (op not callable)")
+        else:
+            t_grouped = bench(
+                lambda: run_grouped_model(x),
                 iters=args.iters,
                 warmup=args.warmup,
             )
-            tflops = flops_total / ((t_ms / 1000.0) * 1e12)
-            speedup = t_naive / t_ms
+            tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
+            speedup = t_naive / t_grouped
             print(
-                f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
+                f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
+                f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
             )
+            diff = (y_ref.float() - y_grouped.float()).abs()
+            max_abs = diff.max().item()
+            mean_abs = diff.mean().item()
+            rel_l2 = (
+                (diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
+            )
+            print(
+                f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
+            )
-            with torch.no_grad():
-                y_fast = y_tg
-                y_ref32 = y_ref.float()
-                y_fast32 = y_fast.float()
-                diff = (y_ref32 - y_fast32).abs()
-                max_abs = diff.max().item()
-                mean_abs = diff.mean().item()
-                rel_l2 = (
-                    (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-                )
-                print(
-                    f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-                )
-        else:
-            print("torch_grouped\tN/A (op not callable)")
     else:
         print("torch_grouped\tN/A (unavailable)")
 
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
-    ) as prof:
-        forward_naive(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+    if args.profile and tg is not None and tg.available():
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
+        ) as prof:
+            run_naive_model(x)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
 
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA],
-        record_shapes=True,
-        with_stack=False,
-    ) as prof:
-        forward_tg(x, gate, experts, args.top_k)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+        with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True,
+            with_stack=False,
+        ) as prof:
+            run_grouped_model(x)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
 
 
 if __name__ == "__main__":
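
For orientation between the script hunk above and the backend hunks below: the grouped path replaces the per-expert Python loop with GEMMs over tokens that have been pre-sorted by expert, so each expert sees one contiguous block of rows. Below is a minimal pure-PyTorch reference of that dispatch layout. It is illustrative only: `run_grouped_reference` and the stacked weights `w_gate`/`w_up` (experts, inter, hidden) and `w_down` (experts, hidden, inter) are this sketch's own names, and the real backend fuses the per-expert loop into a single grouped GEMM.

    import torch
    import torch.nn.functional as F

    def run_grouped_reference(x, gate, w_gate, w_up, w_down, top_k):
        # x: (tokens, hidden); gate: nn.Linear(hidden, experts)
        probs, idx = torch.topk(gate(x).softmax(dim=-1), top_k)  # (tokens, top_k)
        flat_idx = idx.reshape(-1)                      # expert id per (token, slot)
        order = torch.argsort(flat_idx)                 # group slots by expert
        x_rep = x.repeat_interleave(top_k, dim=0)[order]  # rows sorted by expert
        counts = torch.bincount(flat_idx, minlength=w_gate.shape[0])
        out = torch.empty_like(x_rep)
        start = 0
        for e, n in enumerate(counts.tolist()):         # contiguous per-expert GEMMs
            if n:
                xe = x_rep[start : start + n]
                h = F.silu(xe @ w_gate[e].T) * (xe @ w_up[e].T)  # SwiGLU
                out[start : start + n] = h @ w_down[e].T
                start += n
        y = torch.empty_like(out)
        y[order] = out                                  # undo the sort
        return (y.view(-1, top_k, x.shape[-1]) * probs.unsqueeze(-1)).sum(dim=1)
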
@@ -267,9 +267,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None
 
-    expert_impls = _iter_expert_impls(experts_module)
+    parent_block = getattr(experts_module, "_ax_parent_block", None)
+    expert_container = getattr(experts_module, "experts", experts_module)
+
+    expert_impls = _iter_expert_impls(expert_container)
     sample_mod = expert_impls[0]
-    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
     w_gate = storage.gate
     w_up = storage.up
     w2 = storage.down
@@ -278,11 +281,12 @@ def moe_ffn_forward_grouped(
     router_logits = gate_linear(x_flat.to(routing_dtype))
 
     shared_out_flat: Optional[torch.Tensor] = None
-    if hasattr(experts_module, "shared_expert"):
-        shared_expert = experts_module.shared_expert
+    shared_owner = parent_block if parent_block is not None else experts_module
+    if hasattr(shared_owner, "shared_expert"):
+        shared_expert = shared_owner.shared_expert
         shared_out_flat = shared_expert(x_flat)
         shared_out_flat = shared_out_flat.to(expert_dtype)
-        shared_gate = getattr(experts_module, "shared_expert_gate", None)
+        shared_gate = getattr(shared_owner, "shared_expert_gate", None)
         if shared_gate is not None:
             gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
             gate_vals = torch.sigmoid(gate_input)
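
Background for the `parent_block` indirection above: in Hugging Face's Qwen2-MoE, `shared_expert` and `shared_expert_gate` live on the sparse MoE block itself, not on the experts container that `moe_ffn_forward_grouped` receives, hence the `_ax_parent_block` back-reference set in the next hunk. The combination being reproduced is, schematically (a sketch; `block` and `routed_out` are placeholders, not repo names):

    shared_out = block.shared_expert(x_flat)                     # dense SwiGLU branch
    gate_vals = torch.sigmoid(block.shared_expert_gate(x_flat))  # per-token scalar gate
    y_flat = routed_out + gate_vals * shared_out
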
@@ -76,6 +76,11 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
     @wraps(orig_forward)
     def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
         bsz, seqlen, hdim = hidden_states.shape
+        # expose parent block so grouped backend can access shared expert context
+        try:
+            self.experts._ax_parent_block = self
+        except Exception:
+            pass
         y, router_logits = _tg.moe_ffn_forward_grouped(
             hidden_states, self.gate, self.experts, self.top_k
         )
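
A hypothetical invocation tying the new flags together (the script's path is not shown in these hunks, and the size flags are inferred from the `args.*` attributes used above):

    # illustrative path and size flags
    python moe_bench.py --hf-block qwen2_moe --iters 50 --warmup 10 --profile

With the default `--hf-block none`, the toy `Experts`/`forward_naive` pair is benchmarked instead, and `--profile` emits profiler tables only when the torch_grouped backend is importable and available, per the `if args.profile and tg is not None and tg.available():` guard.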