diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py
index 162a3a53b..c0885152f 100644
--- a/scripts/benchmarks/deepseek_v3_moe_sweep.py
+++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py
@@ -50,9 +50,9 @@ def parse_args() -> argparse.Namespace:
         help="Override GROUP_SIZE_M for every configuration",
     )
     parser.add_argument(
-        "--uniform-routing",
+        "--no-uniform-routing",
         action="store_true",
-        help="Force uniform routing for every configuration",
+        help="Disable uniform routing for every configuration",
     )
     parser.add_argument(
         "--include-mixtral-long",
@@ -79,7 +79,7 @@ def make_namespace(
             "warmup": args.warmup,
             "iters": args.iters,
             "seed": args.seed,
-            "uniform_routing": args.uniform_routing,
+            "uniform_routing": not args.no_uniform_routing,
         }
     )
     if args.group_size is not None:
@@ -196,8 +196,9 @@ def main() -> None:  # pragma: no cover - utility script
     )
 
     rows = []
+    uniform_flag = not args.no_uniform_routing
     print(
-        f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}"
+        f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
     )
     print(
         f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
index eceba8e7b..6d1e97d65 100644
--- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py
+++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
@@ -132,7 +132,7 @@ def _run_cg_grouped_gemm(
     group_size_m: int,
     hidden_dtype: torch.dtype,
     device: torch.device,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
 
     expert_index_tensor = torch.repeat_interleave(
@@ -167,17 +167,11 @@ def _run_cg_grouped_gemm(
             expert_index_tensor,
             group_size_m,
         )
-        down_out = ContiguousGroupedGEMM.apply(
-            grouped_hidden,
-            down_weights,
-            expert_index_tensor,
-            group_size_m,
-        )
-
         return (
             gate_out.to(hidden_dtype),
             up_out.to(hidden_dtype),
-            down_out.to(hidden_dtype),
+            down_weights,
+            expert_index_tensor,
         )
 
     gate_out = mg_grouped_gemm(
@@ -296,7 +290,7 @@ def _moe_triton_forward(
             m_sizes_tensor,
         ).to(hidden_dtype)
     else:
-        gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
+        gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
             module,
             grouped_hidden,
             m_sizes,
@@ -330,7 +324,12 @@ def _moe_triton_forward(
             m_sizes_tensor,
         ).to(hidden_dtype)
     else:
-        down_out = down_out_cg
+        down_out = ContiguousGroupedGEMM.apply(
+            hidden_grouped,
+            down_weights,
+            expert_index_tensor,
+            group_size_m,
+        ).to(hidden_dtype)
 
     if valid_positions.numel() > 0:
         down_valid = down_out.index_select(0, valid_positions)
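
The benchmark-script half of the patch flips the CLI switch from an opt-in `--uniform-routing` to an opt-out `--no-uniform-routing`, so uniform routing becomes the default and every consumer negates the stored boolean. A minimal standalone sketch of that inverted-flag pattern (only the one flag, flag name taken from the diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Opt-out flag: absent by default, so uniform routing defaults to on.
parser.add_argument(
    "--no-uniform-routing",
    action="store_true",
    help="Disable uniform routing for every configuration",
)

args = parser.parse_args([])                       # flag not passed
assert (not args.no_uniform_routing) is True       # uniform routing on

args = parser.parse_args(["--no-uniform-routing"])
assert (not args.no_uniform_routing) is False      # uniform routing off
```

This is why the diff computes `uniform_flag = not args.no_uniform_routing` once in `main()` rather than printing the raw attribute, and negates it again in `make_namespace`.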
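
In the monkeypatch file, `_run_cg_grouped_gemm` builds `expert_index_tensor` with `torch.repeat_interleave`. The operands are not visible in the hunk, but a common construction, sketched here as an assumption (`m_sizes` holding per-expert row counts), expands expert ids into one id per grouped row:

```python
import torch

m_sizes = torch.tensor([3, 0, 2])           # rows routed to each expert (assumed semantics)
expert_ids = torch.arange(m_sizes.numel())  # tensor([0, 1, 2])

# One expert id per row of the grouped hidden states, in contiguous order.
expert_index_tensor = torch.repeat_interleave(expert_ids, m_sizes)
print(expert_index_tensor)                  # tensor([0, 0, 0, 2, 2])
```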
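
The substantive change is an ordering fix in the contiguous-grouped (CG) path: `_run_cg_grouped_gemm` no longer runs the down projection on the raw `grouped_hidden`; it returns `down_weights` and `expert_index_tensor` so that `_moe_triton_forward` can apply the down GEMM to `hidden_grouped`, which the surrounding code presumably computes as the activated gate/up product (that detail is an assumption here). A plain-PyTorch sketch of the intended dataflow, with dense matmuls standing in for `ContiguousGroupedGEMM.apply` and illustrative shapes:

```python
import torch
import torch.nn.functional as F

def cg_path_sketch(grouped_hidden, gate_w, up_w, down_w):
    # Gate/up projections (stand-ins for the two ContiguousGroupedGEMM calls).
    gate_out = grouped_hidden @ gate_w
    up_out = grouped_hidden @ up_w
    # The activation must run BEFORE the down projection, which is why the
    # helper now defers the down GEMM instead of computing it eagerly.
    hidden_grouped = F.silu(gate_out) * up_out  # SwiGLU-style gating (assumed act_fn)
    # Down projection consumes the activated product, mirroring the new
    # ContiguousGroupedGEMM.apply(hidden_grouped, down_weights, ...) call site.
    return hidden_grouped @ down_w

x = torch.randn(8, 16)
out = cg_path_sketch(x, torch.randn(16, 32), torch.randn(16, 32), torch.randn(32, 16))
print(out.shape)  # torch.Size([8, 16])
```

Returning the weights and index tensor (hence the widened 4-tuple return annotation) keeps the expensive GEMM inside the autograd-aware `ContiguousGroupedGEMM.apply` at the caller, rather than duplicating the activation logic inside the helper.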