uniform routing default
This commit is contained in:
@@ -50,9 +50,9 @@ def parse_args() -> argparse.Namespace:
|
|||||||
help="Override GROUP_SIZE_M for every configuration",
|
help="Override GROUP_SIZE_M for every configuration",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--uniform-routing",
|
"--no-uniform-routing",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Force uniform routing for every configuration",
|
help="Disable uniform routing for every configuration",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--include-mixtral-long",
|
"--include-mixtral-long",
|
||||||
@@ -79,7 +79,7 @@ def make_namespace(
|
|||||||
"warmup": args.warmup,
|
"warmup": args.warmup,
|
||||||
"iters": args.iters,
|
"iters": args.iters,
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
"uniform_routing": args.uniform_routing,
|
"uniform_routing": not args.no_uniform_routing,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if args.group_size is not None:
|
if args.group_size is not None:
|
||||||
@@ -196,8 +196,9 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
)
|
)
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
|
uniform_flag = not args.no_uniform_routing
|
||||||
print(
|
print(
|
||||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}"
|
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ def _run_cg_grouped_gemm(
|
|||||||
group_size_m: int,
|
group_size_m: int,
|
||||||
hidden_dtype: torch.dtype,
|
hidden_dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
|
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
|
||||||
|
|
||||||
expert_index_tensor = torch.repeat_interleave(
|
expert_index_tensor = torch.repeat_interleave(
|
||||||
@@ -167,17 +167,11 @@ def _run_cg_grouped_gemm(
|
|||||||
expert_index_tensor,
|
expert_index_tensor,
|
||||||
group_size_m,
|
group_size_m,
|
||||||
)
|
)
|
||||||
down_out = ContiguousGroupedGEMM.apply(
|
|
||||||
grouped_hidden,
|
|
||||||
down_weights,
|
|
||||||
expert_index_tensor,
|
|
||||||
group_size_m,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
gate_out.to(hidden_dtype),
|
gate_out.to(hidden_dtype),
|
||||||
up_out.to(hidden_dtype),
|
up_out.to(hidden_dtype),
|
||||||
down_out.to(hidden_dtype),
|
down_weights,
|
||||||
|
expert_index_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
gate_out = mg_grouped_gemm(
|
gate_out = mg_grouped_gemm(
|
||||||
@@ -296,7 +290,7 @@ def _moe_triton_forward(
|
|||||||
m_sizes_tensor,
|
m_sizes_tensor,
|
||||||
).to(hidden_dtype)
|
).to(hidden_dtype)
|
||||||
else:
|
else:
|
||||||
gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
|
gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
|
||||||
module,
|
module,
|
||||||
grouped_hidden,
|
grouped_hidden,
|
||||||
m_sizes,
|
m_sizes,
|
||||||
@@ -330,7 +324,12 @@ def _moe_triton_forward(
|
|||||||
m_sizes_tensor,
|
m_sizes_tensor,
|
||||||
).to(hidden_dtype)
|
).to(hidden_dtype)
|
||||||
else:
|
else:
|
||||||
down_out = down_out_cg
|
down_out = ContiguousGroupedGEMM.apply(
|
||||||
|
hidden_grouped,
|
||||||
|
down_weights,
|
||||||
|
expert_index_tensor,
|
||||||
|
group_size_m,
|
||||||
|
).to(hidden_dtype)
|
||||||
|
|
||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
down_valid = down_out.index_select(0, valid_positions)
|
down_valid = down_out.index_select(0, valid_positions)
|
||||||
|
|||||||
Reference in New Issue
Block a user