minify
This commit is contained in:
@@ -54,7 +54,7 @@ def forward_naive(
|
||||
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||
# warmup
|
||||
for _ in range(warmup):
|
||||
out = fn(*args)
|
||||
fn(*args)
|
||||
if sync and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
# measure
|
||||
@@ -63,7 +63,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||
if sync and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = fn(*args)
|
||||
fn(*args)
|
||||
if sync and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
dt = (time.perf_counter() - t0) * 1000.0
|
||||
@@ -185,12 +185,7 @@ def main():
|
||||
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
|
||||
except Exception:
|
||||
_TG_ERR = None
|
||||
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
|
||||
print(f"torch_grouped\tN/A (op not callable){reason}")
|
||||
print("torch_grouped\tN/A (op not callable)")
|
||||
else:
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user