bench fix
This commit is contained in:
@@ -38,7 +38,15 @@ def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> floa
|
|||||||
return 6.0 * tokens * top_k * hidden * inter
|
return 6.0 * tokens * top_k * hidden * inter
|
||||||
|
|
||||||
|
|
||||||
def load_hf_block(hidden: int, inter: int, experts: int, top_k: int, *, device: torch.device, dtype: torch.dtype):
|
def load_hf_block(
|
||||||
|
hidden: int,
|
||||||
|
inter: int,
|
||||||
|
experts: int,
|
||||||
|
top_k: int,
|
||||||
|
*,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
project_root = Path(__file__).resolve().parents[2]
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
transformers_src = project_root / "transformers" / "src"
|
transformers_src = project_root / "transformers" / "src"
|
||||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||||
@@ -114,12 +122,16 @@ def main() -> None:
|
|||||||
if tg is None or not tg.available():
|
if tg is None or not tg.available():
|
||||||
return torch.empty(0)
|
return torch.empty(0)
|
||||||
block_grouped.experts._ax_parent_block = block_grouped
|
block_grouped.experts._ax_parent_block = block_grouped
|
||||||
y, _ = tg.moe_ffn_forward_grouped(x, block_grouped.gate, block_grouped.experts, block_grouped.top_k)
|
y, _ = tg.moe_ffn_forward_grouped(
|
||||||
|
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
|
||||||
|
)
|
||||||
return y if y is not None else torch.empty(0)
|
return y if y is not None else torch.empty(0)
|
||||||
|
|
||||||
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||||
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
||||||
print(f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s")
|
print(
|
||||||
|
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
y_ref = run_naive()
|
y_ref = run_naive()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def main() -> None:
|
|||||||
|
|
||||||
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
|
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
|
||||||
experts = block.experts
|
experts = block.experts
|
||||||
setattr(experts, "_ax_parent_block", block)
|
experts._ax_parent_block = block
|
||||||
|
|
||||||
impls = _iter_expert_impls(experts)
|
impls = _iter_expert_impls(experts)
|
||||||
print(f"impl count: {len(impls)}")
|
print(f"impl count: {len(impls)}")
|
||||||
|
|||||||
@@ -30,7 +30,17 @@ def available() -> bool:
|
|||||||
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
||||||
impls: List[torch.nn.Module] = []
|
impls: List[torch.nn.Module] = []
|
||||||
for exp in experts_module:
|
for exp in experts_module:
|
||||||
impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
|
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||||
|
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
|
||||||
|
impls.append(candidate)
|
||||||
|
continue
|
||||||
|
nested = getattr(candidate, "experts", None)
|
||||||
|
if nested is not None:
|
||||||
|
impls.extend(_iter_expert_impls(nested))
|
||||||
|
continue
|
||||||
|
raise RuntimeError(
|
||||||
|
"torch_grouped: unable to resolve expert implementation for module"
|
||||||
|
)
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user