cfg value
This commit is contained in:
@@ -34,8 +34,15 @@ SCENARIOS: tuple[Scenario, ...] = (
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--device", default="cuda", choices=["cuda"], help="Execution device")
|
||||
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="Computation dtype")
|
||||
parser.add_argument(
|
||||
"--device", default="cuda", choices=["cuda"], help="Execution device"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bf16",
|
||||
choices=["bf16", "fp16", "fp32"],
|
||||
help="Computation dtype",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
@@ -56,7 +63,9 @@ def pick_dtype(name: str) -> torch.dtype:
|
||||
}[name]
|
||||
|
||||
|
||||
def make_indices(num_groups: int, group_size: int, device: torch.device) -> torch.Tensor:
|
||||
def make_indices(
|
||||
num_groups: int, group_size: int, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
|
||||
return indices.repeat_interleave(group_size)
|
||||
|
||||
@@ -82,7 +91,9 @@ def run_scenario(
|
||||
group_size_m: int,
|
||||
) -> dict:
|
||||
if scenario.m % scenario.num_groups != 0:
|
||||
raise ValueError(f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})")
|
||||
raise ValueError(
|
||||
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
|
||||
)
|
||||
group_size = scenario.m // scenario.num_groups
|
||||
if group_size % group_size_m != 0:
|
||||
raise ValueError(
|
||||
@@ -90,7 +101,9 @@ def run_scenario(
|
||||
)
|
||||
|
||||
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
|
||||
weights = torch.randn(scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype)
|
||||
weights = torch.randn(
|
||||
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
|
||||
)
|
||||
indices = make_indices(scenario.num_groups, group_size, device)
|
||||
|
||||
def persistent():
|
||||
|
||||
Reference in New Issue
Block a user