bench fix
This commit is contained in:
@@ -38,7 +38,15 @@ def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> floa
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def load_hf_block(hidden: int, inter: int, experts: int, top_k: int, *, device: torch.device, dtype: torch.dtype):
|
||||
def load_hf_block(
|
||||
hidden: int,
|
||||
inter: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
transformers_src = project_root / "transformers" / "src"
|
||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||
@@ -114,12 +122,16 @@ def main() -> None:
|
||||
if tg is None or not tg.available():
|
||||
return torch.empty(0)
|
||||
block_grouped.experts._ax_parent_block = block_grouped
|
||||
y, _ = tg.moe_ffn_forward_grouped(x, block_grouped.gate, block_grouped.experts, block_grouped.top_k)
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
|
||||
)
|
||||
return y if y is not None else torch.empty(0)
|
||||
|
||||
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
||||
print(f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s")
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
y_ref = run_naive()
|
||||
|
||||
@@ -33,7 +33,7 @@ def main() -> None:
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
|
||||
experts = block.experts
|
||||
setattr(experts, "_ax_parent_block", block)
|
||||
experts._ax_parent_block = block
|
||||
|
||||
impls = _iter_expert_impls(experts)
|
||||
print(f"impl count: {len(impls)}")
|
||||
|
||||
@@ -30,7 +30,17 @@ def available() -> bool:
|
||||
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
    """Collect the concrete expert modules reachable from ``experts_module``.

    Each item yielded by iterating ``experts_module`` is first unwrapped:
    its ``mlp`` attribute is used if present, otherwise its ``ffn``
    attribute, otherwise the item itself.

    * If the unwrapped module has both ``gate_proj`` and ``up_proj``
      attributes it is collected as an expert implementation.
    * Otherwise, if it has a non-``None`` ``experts`` attribute, that
      attribute is searched recursively and its results are appended in
      order.

    Args:
        experts_module: Any iterable of expert modules or wrappers.

    Returns:
        A flat list of the resolved expert implementations, in iteration
        order. Each resolved module appears once per occurrence.

    Raises:
        RuntimeError: If an item is neither an expert implementation nor a
            container exposing ``experts``.
    """
    impls: List[torch.nn.Module] = []
    for exp in experts_module:
        # Unwrap common wrapper attributes; fall back to the item itself.
        candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))

        # Only modules that expose both projections are accepted directly.
        # (The pre-fix code appended every candidate unconditionally.)
        if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
            impls.append(candidate)
            continue

        # Otherwise descend into a nested container of experts, if any.
        nested = getattr(candidate, "experts", None)
        if nested is not None:
            impls.extend(_iter_expert_impls(nested))
            continue

        raise RuntimeError(
            "torch_grouped: unable to resolve expert implementation for module"
        )
    return impls
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user