token shuffle kernel
This commit is contained in:
@@ -36,8 +36,8 @@ DTYPE_MAP = {
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--batch", type=int, default=2, help="batch size")
|
||||
parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
|
||||
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
|
||||
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
|
||||
parser.add_argument(
|
||||
"--moe-intermediate-size",
|
||||
@@ -48,13 +48,13 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"--n-experts",
|
||||
type=int,
|
||||
default=64,
|
||||
default=256,
|
||||
help="Number of routed experts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=4,
|
||||
default=8,
|
||||
help="Number of experts per token",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -153,6 +153,10 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
baseline_module.to(device=device, dtype=dtype)
|
||||
patched_module.to(device=device, dtype=dtype)
|
||||
|
||||
tokens = args.batch * args.seq_len
|
||||
routed_tokens = tokens * args.top_k
|
||||
avg_tokens_per_expert = routed_tokens / args.n_experts
|
||||
|
||||
inputs = torch.randn(
|
||||
args.batch,
|
||||
args.seq_len,
|
||||
@@ -174,6 +178,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
print(
|
||||
f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||
)
|
||||
print(
|
||||
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(
|
||||
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user