This commit is contained in:
Dan Saunders
2025-09-17 16:42:35 -04:00
parent 129db67705
commit fd87eed501
5 changed files with 91 additions and 297 deletions

View File

@@ -54,7 +54,7 @@ def forward_naive(
def bench(fn, *args, iters=50, warmup=10, sync=True):
# warmup
for _ in range(warmup):
out = fn(*args)
fn(*args)
if sync and torch.cuda.is_available():
torch.cuda.synchronize()
# measure
@@ -63,7 +63,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
if sync and torch.cuda.is_available():
torch.cuda.synchronize()
t0 = time.perf_counter()
out = fn(*args)
fn(*args)
if sync and torch.cuda.is_available():
torch.cuda.synchronize()
dt = (time.perf_counter() - t0) * 1000.0
@@ -185,12 +185,7 @@ def main():
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
try:
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
except Exception:
_TG_ERR = None
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
print(f"torch_grouped\tN/A (op not callable){reason}")
print("torch_grouped\tN/A (op not callable)")
else:
print("torch_grouped\tN/A (unavailable)")