token shuffle kernel

This commit is contained in:
Dan Saunders
2025-09-21 16:46:46 -04:00
parent 18269ee6a9
commit 5c74edeefe
5 changed files with 231 additions and 40 deletions

View File

@@ -36,8 +36,8 @@ DTYPE_MAP = {
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
-parser.add_argument("--batch", type=int, default=2, help="batch size")
-parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
+parser.add_argument("--batch", type=int, default=8, help="batch size")
+parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
@@ -48,13 +48,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n-experts",
type=int,
-default=64,
+default=256,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
-default=4,
+default=8,
help="Number of experts per token",
)
parser.add_argument(
@@ -153,6 +153,10 @@ def main() -> None: # pragma: no cover - CLI entrypoint
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
tokens = args.batch * args.seq_len
+routed_tokens = tokens * args.top_k
+avg_tokens_per_expert = routed_tokens / args.n_experts
inputs = torch.randn(
args.batch,
args.seq_len,
@@ -174,6 +178,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
print(
f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
+print(
+f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
+)
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
)