bench fix

This commit is contained in:
Dan Saunders
2025-09-19 12:34:08 -04:00
parent 1e7302d30a
commit 42aadc5069
3 changed files with 27 additions and 5 deletions

View File

@@ -38,7 +38,15 @@ def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> floa
return 6.0 * tokens * top_k * hidden * inter
def load_hf_block(hidden: int, inter: int, experts: int, top_k: int, *, device: torch.device, dtype: torch.dtype):
def load_hf_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
@@ -114,12 +122,16 @@ def main() -> None:
if tg is None or not tg.available():
return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped
y, _ = tg.moe_ffn_forward_grouped(x, block_grouped.gate, block_grouped.experts, block_grouped.top_k)
y, _ = tg.moe_ffn_forward_grouped(
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s")
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()

View File

@@ -33,7 +33,7 @@ def main() -> None:
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
setattr(experts, "_ax_parent_block", block)
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")

View File

@@ -30,7 +30,17 @@ def available() -> bool:
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
impls: List[torch.nn.Module] = []
for exp in experts_module:
impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
impls.append(candidate)
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"
)
return impls