diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 78fdff374..1815096ec 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -38,7 +38,15 @@ def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> floa
     return 6.0 * tokens * top_k * hidden * inter
 
 
-def load_hf_block(hidden: int, inter: int, experts: int, top_k: int, *, device: torch.device, dtype: torch.dtype):
+def load_hf_block(
+    hidden: int,
+    inter: int,
+    experts: int,
+    top_k: int,
+    *,
+    device: torch.device,
+    dtype: torch.dtype,
+):
     project_root = Path(__file__).resolve().parents[2]
     transformers_src = project_root / "transformers" / "src"
     if transformers_src.exists() and str(transformers_src) not in sys.path:
@@ -114,12 +122,16 @@ def main() -> None:
         if tg is None or not tg.available():
             return torch.empty(0)
         block_grouped.experts._ax_parent_block = block_grouped
-        y, _ = tg.moe_ffn_forward_grouped(x, block_grouped.gate, block_grouped.experts, block_grouped.top_k)
+        y, _ = tg.moe_ffn_forward_grouped(
+            x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
+        )
         return y if y is not None else torch.empty(0)
 
     t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
     tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
-    print(f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s")
+    print(
+        f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
+    )
 
     with torch.no_grad():
         y_ref = run_naive()
diff --git a/scripts/debug_qwen2_experts.py b/scripts/debug_qwen2_experts.py
index 9b9057689..7ec22e0bc 100644
--- a/scripts/debug_qwen2_experts.py
+++ b/scripts/debug_qwen2_experts.py
@@ -33,7 +33,7 @@ def main() -> None:
     block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
 
     experts = block.experts
-    setattr(experts, "_ax_parent_block", block)
+    experts._ax_parent_block = block
 
     impls = _iter_expert_impls(experts)
     print(f"impl count: {len(impls)}")
diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 54299908d..2f4d8c6cb 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -30,7 +30,17 @@ def available() -> bool:
 
 
 def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
     impls: List[torch.nn.Module] = []
     for exp in experts_module:
-        impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
+        candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
+        if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
+            impls.append(candidate)
+            continue
+        nested = getattr(candidate, "experts", None)
+        if nested is not None:
+            impls.extend(_iter_expert_impls(nested))
+            continue
+        raise RuntimeError(
+            "torch_grouped: unable to resolve expert implementation for module"
+        )
     return impls