#!/usr/bin/env python3
"""
Axolotl Training Profiler Analyzer
===================================
Analyzes PyTorch profiler output from axolotl training runs (profiler_steps config).
Produces breakdowns by CUDA kernel category, identifies bottlenecks, and optionally
compares two traces (e.g. before/after optimization).
Supports both:
- profiler_trace.json (torch.profiler Chrome trace -- timing analysis)
- snapshot.pickle (torch.cuda.memory._snapshot -- memory analysis)
Usage:
# Analyze a single trace
python analyze_profile.py outputs/qwen35_moe_profile/
# Compare before vs after
python analyze_profile.py outputs/before/ --compare outputs/after/
# Include step 0 (warmup/compilation) in analysis
python analyze_profile.py outputs/run/ --include-warmup
# Memory-only analysis
python analyze_profile.py outputs/run/ --memory-only
# Quick mode (first 2M events only, for large traces)
python analyze_profile.py outputs/run/ --quick
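
Inputs are produced by enabling the profiler in the axolotl config, e.g.
(illustrative snippet; exact option names may vary by axolotl version):
    profiler_steps: 5           # capture N steps with torch.profiler
    output_dir: outputs/run/    # profiler_trace.json / snapshot.pickle land here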
"""
import argparse
import json
import pickle # nosec B403
import time
from collections import defaultdict
from pathlib import Path
# ---- Kernel categorization ------------------------------------------------
KERNEL_CATEGORIES = [
# ScatterMoE -- ordered most specific first
(
"ScatterMoE bwd LoRA (split)",
["_group_bwd_lora_split", "_group_bwd_da", "_group_bwd_db"],
),
("ScatterMoE bwd LoRA (fused)", ["_group_bwd_lora"]),
("ScatterMoE bwd dX", ["_scatter2scatter_lora_dx"]),
("ScatterMoE fwd", ["_scatter2scatter_lora"]),
# Quantization
("BnB Dequantization", ["dequantize", "kDequantizeBlockwise"]),
# Attention
("Flash Attention", ["flash", "fmha"]),
# Loss
("CCE Loss", ["_cce_"]),
# LoRA fused kernels (autograd.Function based)
("LoRA QKV Kernel", ["lora_qkv"]),
("LoRA O Kernel", ["lora_o"]),
("LoRA MLP Kernel", ["lora_mlp"]),
# LoRA activation kernels (SwiGLU/GEGLU Triton kernels)
("LoRA Activation (SwiGLU/GEGLU)", ["swiglu", "geglu"]),
# DoRA weight norm
("DoRA Weight Norm", ["linalg_norm", "dora_scale"]),
# Compute
("GEMM/CUTLASS", ["cutlass", "gemm", "gemv", "cublas"]),
("Triton (norms etc)", ["triton"]),
("Conv1d", ["conv1d", "causal_conv"]),
# Optimizer
("Optimizer", ["adam", "optim"]),
# Dtype conversion (fp32→bf16 LoRA matrix casts etc)
("Dtype Conversion", ["_to_copy", "to_copy"]),
# Memory
("Elementwise/Fill", ["fill", "elementwise", "cast", "copy_kernel"]),
("Memory ops", ["memcpy", "memset"]),
# Routing
("TopK/Sort", ["topk", "sort"]),
("Index/Gather/Scatter", ["index", "gather", "scatter"]),
]
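# Matching below is first-hit, so more specific patterns must precede broader
# ones: a (hypothetical) kernel named "_scatter2scatter_lora_dx_kernel" matches
# "ScatterMoE bwd dX" before it could fall into the broader "ScatterMoE fwd".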
# Categories to keep when pre-filtering during streaming load
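# (torch.profiler Chrome-trace categories -- roughly: "kernel" = GPU kernels,
# "gpu_memcpy" = transfers, "cpu_op" = ATen-level ops, "python_function" =
# Python stacks when profiling with with_stack=True, "ac2g" = CPU->GPU flow
# arrows, "Runtime" = CUDA API calls; names can vary across torch versions.)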
_KEEP_CATS = {"kernel", "gpu_memcpy", "cpu_op", "python_function", "ac2g", "Runtime"}
def categorize_kernel(name):
nl = name.lower()
for cat_name, patterns in KERNEL_CATEGORIES:
if any(p in nl for p in patterns):
return cat_name
return "Other"
# ---- Trace loading ---------------------------------------------------------
def _try_ijson_load(trace_file, quick=False, max_events=2_000_000):
"""Stream-parse trace JSON with ijson. Returns filtered events list."""
try:
import ijson
except ImportError:
return None
events = []
count = 0
with open(trace_file, "rb") as f:
for ev in ijson.items(f, "traceEvents.item"):
count += 1
cat = ev.get("cat", "")
# Pre-filter: only keep categories we care about
if cat in _KEEP_CATS or ev.get("ph") == "M":
events.append(ev)
if quick and count >= max_events:
print(f" --quick: stopped after {count:,} events")
break
return events, count
def load_trace(path, quick=False):
trace_file = (
Path(path) / "profiler_trace.json" if Path(path).is_dir() else Path(path)
)
if not trace_file.exists():
return None
size_gb = trace_file.stat().st_size / 1e9
print(f"Loading {trace_file.name} ({size_gb:.1f} GB)...")
t0 = time.monotonic()
# Try streaming parser first for large files
if size_gb > 0.5:
result = _try_ijson_load(trace_file, quick=quick)
if result is not None:
events, total_count = result
elapsed = time.monotonic() - t0
print(
f" {total_count:,} total events, {len(events):,} kept "
f"(streamed in {elapsed:.1f}s)"
)
return events
# Fallback: standard json.load
with open(trace_file) as f:
data = json.load(f)
all_events = data.get("traceEvents", [])
elapsed = time.monotonic() - t0
if quick:
all_events = all_events[:2_000_000]
print(f" --quick: limited to first {len(all_events):,} events")
# Pre-filter
events = [
ev
for ev in all_events
if ev.get("cat", "") in _KEEP_CATS or ev.get("ph") == "M"
]
print(
f" {len(all_events):,} total events, {len(events):,} kept "
f"(loaded in {elapsed:.1f}s)"
)
return events
# ---- Trace analysis -------------------------------------------------------
def _estimate_n_steps(cuda_events):
"""Estimate the number of training steps from CUDA event timestamps.
    Detects step boundaries by looking for unusually large gaps (>50x the
    median gap, with a 100 ms floor) in the sorted timestamp sequence of
    CUDA kernels.
"""
if len(cuda_events) < 100:
return 1
timestamps = sorted(float(ev.get("ts", 0)) for ev in cuda_events)
# Compute gaps between consecutive events
gaps = [timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)]
if not gaps:
return 1
median_gap = sorted(gaps)[len(gaps) // 2]
# A step boundary has a gap much larger than the median inter-kernel gap
threshold = max(median_gap * 50, 100_000) # at least 100ms
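    # Worked example: with a ~50 us median inter-kernel gap, the 50x rule gives
    # 2.5 ms, so the 100 ms floor dominates and only genuine between-step idle
    # periods count as boundaries (trace timestamps are in microseconds).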
n_boundaries = sum(1 for g in gaps if g > threshold)
return max(n_boundaries + 1, 1)
def analyze_trace(events, skip_warmup=True):
cuda_events = [
ev
for ev in events
if ev.get("ph") == "X" and ev.get("cat") in ("kernel", "gpu_memcpy")
]
if not cuda_events:
print(" No CUDA kernel events found!")
return None
cutoff_ts = None
if skip_warmup and len(cuda_events) > 1000:
timestamps = sorted(set(float(ev.get("ts", 0)) for ev in cuda_events))
min_ts, max_ts = timestamps[0], timestamps[-1]
total_span = max_ts - min_ts
# Step 0 is warmup (Triton compilation + autotune). It's typically
# the slowest step by far. Use 45% of wall-clock as cutoff -- step 0
# usually takes >50% of total time when it includes compilation.
cutoff_ts = min_ts + total_span * 0.45
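        # Rough by design: if step 0 spans 55% of the trace, the 45% cutoff
        # drops most of it but keeps its tail; pass --include-warmup to skip
        # this filtering entirely.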
before = len(cuda_events)
cuda_events = [ev for ev in cuda_events if float(ev.get("ts", 0)) > cutoff_ts]
print(f" Excluding step 0 (warmup): {before:,} -> {len(cuda_events):,} events")
n_steps_profiled = _estimate_n_steps(cuda_events)
# Aggregate by kernel (cast to float to handle ijson Decimal values)
kernel_stats = defaultdict(lambda: {"total_us": 0.0, "count": 0, "max_us": 0.0})
for ev in cuda_events:
name = ev.get("name", "unknown")
dur = float(ev.get("dur", 0))
kernel_stats[name]["total_us"] += dur
kernel_stats[name]["count"] += 1
kernel_stats[name]["max_us"] = max(kernel_stats[name]["max_us"], dur)
total_cuda = sum(v["total_us"] for v in kernel_stats.values())
# Group by category
cat_stats = defaultdict(lambda: {"total_us": 0, "count": 0})
for name, info in kernel_stats.items():
cat = categorize_kernel(name)
cat_stats[cat]["total_us"] += info["total_us"]
cat_stats[cat]["count"] += info["count"]
# Fill/zero_ analysis: find FillFunctor kernels and group by tensor size
fill_by_size = defaultdict(lambda: {"total_us": 0, "count": 0})
for ev in cuda_events:
name = ev.get("name", "")
if "FillFunctor" not in name and "fill" not in name.lower():
continue
# Extract Input Dims from args if present
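        # "Input Dims" is only populated when the trace was recorded with
        # torch.profiler(record_shapes=True); otherwise sizes bucket as
        # "unknown".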
args = ev.get("args", {})
input_dims = args.get("Input Dims", args.get("input_dims", "unknown"))
if isinstance(input_dims, list):
input_dims = str(input_dims)
dur = float(ev.get("dur", 0))
fill_by_size[input_dims]["total_us"] += dur
fill_by_size[input_dims]["count"] += 1
# CPU op analysis for wall-clock estimation (apply same warmup cutoff)
cpu_ops = [
ev
for ev in events
if ev.get("ph") == "X" and ev.get("cat") in ("cpu_op", "python_function")
]
if cutoff_ts is not None:
cpu_ops = [ev for ev in cpu_ops if float(ev.get("ts", 0)) > cutoff_ts]
wall_clock_us = 0
if cpu_ops:
ts_sorted = sorted(cpu_ops, key=lambda e: float(e.get("ts", 0)))
min_cpu_ts = float(ts_sorted[0].get("ts", 0))
max_cpu_end = max(
float(e.get("ts", 0)) + float(e.get("dur", 0)) for e in ts_sorted
)
wall_clock_us = max_cpu_end - min_cpu_ts
return {
"total_cuda_us": total_cuda,
"n_steps": n_steps_profiled,
"categories": dict(cat_stats),
"kernel_stats": dict(kernel_stats),
"n_events": len(cuda_events),
"fill_by_size": dict(fill_by_size),
"wall_clock_us": wall_clock_us,
}
def print_trace_analysis(result, label=""):
total = result["total_cuda_us"]
n = result["n_steps"]
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
print(
f"\n CUDA kernel time: {total / 1e6:.2f}s over {n} steps "
f"(~{total / n / 1e6:.2f}s/step)"
)
if result.get("wall_clock_us"):
wc = result["wall_clock_us"]
print(
f" Wall clock span: {wc / 1e6:.2f}s over {n} steps "
f"(~{wc / n / 1e6:.2f}s/step)"
)
print(f"\n {'Category':<40} {'Total':>9} {'%':>6} {'Count':>7} {'Per step':>9}")
print(f" {'-' * 75}")
for cat, info in sorted(
result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True
):
pct = info["total_us"] / total * 100
ps = info["total_us"] / n / 1000
print(
f" {cat:<40} {info['total_us'] / 1000:>8.1f}ms {pct:>5.1f}% "
f"{info['count']:>7} {ps:>7.1f}ms"
)
print("\n Top 15 individual kernels:")
for name, info in sorted(
result["kernel_stats"].items(), key=lambda x: x[1]["total_us"], reverse=True
)[:15]:
pct = info["total_us"] / total * 100
avg = info["total_us"] / info["count"] / 1000
print(
f" {name[:62]:<62} {info['total_us'] / 1000:>7.1f}ms "
f"({pct:>4.1f}%) x{info['count']:<5} avg={avg:.3f}ms"
)
# Fill/zero_ breakdown
fill_data = result.get("fill_by_size", {})
if fill_data:
total_fill_us = sum(v["total_us"] for v in fill_data.values())
if total_fill_us > 0:
print(
f"\n Fill/zero_ kernel breakdown by tensor size "
f"(total: {total_fill_us / 1000:.1f}ms, "
f"{total_fill_us / total * 100:.1f}% of CUDA time):"
)
print(f" {'Input Dims':<50} {'Time':>9} {'Count':>7}")
print(f" {'-' * 70}")
for dims, info in sorted(
fill_data.items(), key=lambda x: x[1]["total_us"], reverse=True
)[:10]:
dims_str = str(dims)[:50]
print(
f" {dims_str:<50} {info['total_us'] / 1000:>7.1f}ms "
f"{info['count']:>7}"
)
def print_summary(result, mem_result=None):
"""Print optimization recommendations and summary."""
total = result["total_cuda_us"]
n = result["n_steps"]
print(f"\n{'=' * 75}")
print(" SUMMARY & RECOMMENDATIONS")
print(f"{'=' * 75}")
# Estimated per-step wall clock
wc = result.get("wall_clock_us", 0)
if wc > 0:
print(f"\n Estimated per-step wall clock: {wc / n / 1e6:.2f}s")
else:
print(f"\n Estimated per-step CUDA time: {total / n / 1e6:.2f}s")
# Memory utilization
if mem_result:
reserved = mem_result["total_reserved"]
allocated = mem_result["total_allocated"]
if reserved > 0:
util_pct = allocated / reserved * 100
print(
f" Memory utilization: {util_pct:.1f}% "
f"({allocated / 1e9:.2f} / {reserved / 1e9:.2f} GB)"
)
# Build recommendations
recommendations = []
cats_sorted = sorted(
result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True
)
# Check top category
if cats_sorted:
top_cat, top_info = cats_sorted[0]
top_pct = top_info["total_us"] / total * 100
if top_pct > 30:
if "GEMM" in top_cat or "CUTLASS" in top_cat:
recommendations.append(
f"GEMM/matmul dominates ({top_pct:.0f}%). "
f"Consider FP8 training, LoRA (fewer params), "
f"or smaller batch size to reduce compute."
)
elif "Attention" in top_cat:
recommendations.append(
f"Attention dominates ({top_pct:.0f}%). "
f"Ensure FlashAttention v2/v3 is active. "
f"Consider reducing sequence length or using sliding window."
)
elif "Fill" in top_cat or "Elementwise" in top_cat:
recommendations.append(
f"Elementwise/Fill ops dominate ({top_pct:.0f}%). "
f"Enable kernel fusion (Liger, torch.compile) to reduce "
f"memory-bound elementwise operations."
)
elif "ScatterMoE" in top_cat:
recommendations.append(
f"MoE routing/scatter dominates ({top_pct:.0f}%). "
f"Check expert count and capacity factor. "
f"Verify ScatterMoE kernels are using optimal block sizes."
)
else:
recommendations.append(
f"'{top_cat}' dominates ({top_pct:.0f}%). "
f"Focus optimization efforts here first."
)
# Check memory ops
for cat, info in cats_sorted:
pct = info["total_us"] / total * 100
if "Memory" in cat and pct > 10:
recommendations.append(
f"Memory ops are {pct:.0f}% of CUDA time. "
f"Consider gradient checkpointing, reducing activation "
f"recomputation, or pinned memory for data loading."
)
break
# Check fill overhead
fill_data = result.get("fill_by_size", {})
total_fill_us = sum(v["total_us"] for v in fill_data.values())
fill_pct = total_fill_us / total * 100 if total > 0 else 0
if fill_pct > 5:
recommendations.append(
f"Fill/zero_ kernels consume {fill_pct:.1f}% of CUDA time. "
f"Large zero-fills suggest excessive tensor allocation. "
f"Consider reusing buffers or lazy initialization."
)
# Check fragmentation
if mem_result and mem_result.get("fragmentation_pct", 0) > 20:
frag = mem_result["fragmentation_pct"]
recommendations.append(
f"Memory fragmentation is {frag:.0f}%. "
f"Use PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True "
f"or max_split_size_mb to reduce fragmentation."
)
# GPU utilization hint from wall clock vs CUDA time
if wc > 0 and total > 0:
gpu_util = total / wc * 100
if gpu_util < 50:
recommendations.append(
f"GPU utilization is ~{gpu_util:.0f}% (CUDA time vs wall clock). "
f"CPU-side bottleneck likely. Profile data loading, "
f"reward computation, or weight sync overhead."
)
# Print top 3
print("\n Top optimization recommendations:")
if not recommendations:
print(" No major issues detected.")
for i, rec in enumerate(recommendations[:3], 1):
print(f" {i}. {rec}")
print()
def compare_traces(before, after):
print(f"\n{'=' * 75}")
print(" COMPARISON")
print(f"{'=' * 75}")
tb = before["total_cuda_us"]
ta = after["total_cuda_us"]
nb = before["n_steps"]
na = after["n_steps"]
print(
f"\n Per-step CUDA time: {tb / nb / 1e6:.2f}s -> {ta / na / 1e6:.2f}s "
f"({(ta / na - tb / nb) / (tb / nb) * 100:+.1f}%)"
)
all_cats = sorted(
set(list(before["categories"]) + list(after["categories"])),
key=lambda c: before["categories"].get(c, {"total_us": 0})["total_us"],
reverse=True,
)
print(
f"\n {'Category':<35} {'Before/step':>11} {'After/step':>11} "
f"{'Delta':>9} {'Speedup':>8}"
)
print(f" {'-' * 78}")
for cat in all_cats:
b = before["categories"].get(cat, {"total_us": 0})["total_us"] / nb
a = after["categories"].get(cat, {"total_us": 0})["total_us"] / na
delta = a - b
speedup = b / a if a > 0 else float("inf")
if b > tb / nb * 0.003 or a > ta / na * 0.003:
print(
f" {cat:<35} {b / 1000:>9.1f}ms {a / 1000:>9.1f}ms "
f"{delta / 1000:>+8.1f}ms {speedup:>7.2f}x"
)
# ---- Allocation churn attribution -----------------------------------------
def _short_path(p):
"""Shorten a file path for display."""
if "/site-packages/" in p:
return "..." + p.split("/site-packages/")[1]
if "/axolotl/src/" in p:
return "axolotl/" + p.split("/axolotl/src/")[1]
if "/axolotl/" in p:
return "axolotl/" + p.split("/axolotl/")[1]
if len(p) > 60:
parts = p.split("/")
return ".../" + "/".join(parts[-3:])
return p
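# e.g. _short_path(".../site-packages/torch/nn/modules/linear.py")
#      -> "...torch/nn/modules/linear.py"  (illustrative input)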
def _get_top_python_frame(frames):
"""Get the first meaningful Python frame from a trace frame list.
Skips internal torch/cuda/autograd frames to find the user-level code
that triggered the allocation. Falls back to first Python frame.
"""
skip_patterns = [
"torch/autograd/",
"torch/utils/checkpoint",
"torch/nn/modules/module.py",
"torch/_dynamo/",
"torch/_compile",
"torch/_ops",
"torch/_functorch/",
"torch/_inductor/",
"torch/library",
"torch/cuda",
"torch/_C",
"<frozen",
"<unknown>",
"fire/core.py",
"runpy",
]
py_frames = [
f
for f in frames
if isinstance(f, dict) and f.get("filename", "").endswith(".py")
]
for fr in py_frames:
fname = fr.get("filename", "")
if not any(p in fname for p in skip_patterns):
return fr
return py_frames[0] if py_frames else None
def _categorize_source(fname, funcname):
"""Assign a human-readable category to an allocation source."""
if "checkpoint" in fname.lower() or "backward" in funcname:
return "Gradient checkpoint recompute"
if "bitsandbytes" in fname:
return "BnB dequantization"
if "scattermoe" in fname or "scatter" in funcname.lower():
return "ScatterMoE LoRA"
if "adam" in funcname.lower() or "optim" in fname.lower() or "inductor" in fname:
return "Optimizer"
if "norm" in funcname.lower():
return "LayerNorm"
if "fla/" in fname or "gated_delta" in fname:
return "FLA linear attention"
return "Other"
def analyze_allocation_churn(snapshot):
"""Group allocation churn by Python source attribution.
For each major churn size, identifies which Python code creates
the tensors (gradient checkpointing, BnB dequant, LoRA conversion, etc).
"""
traces = snapshot.get("device_traces", [[]])[0]
if not traces:
return None
# Find the top churn sizes
size_counts = defaultdict(int)
for ev in traces:
if ev.get("action") == "alloc":
size_counts[ev["size"]] += 1
# Top 5 sizes by total churn (count × size)
top_sizes = sorted(size_counts.items(), key=lambda x: x[0] * x[1], reverse=True)[:5]
results = {}
for target_sz, total_count in top_sizes:
mb = target_sz / 1e6
if mb < 1.0:
continue
frame_groups = defaultdict(lambda: {"count": 0})
for ev in traces:
if ev.get("action") == "alloc" and ev.get("size") == target_sz:
top = _get_top_python_frame(ev.get("frames", []))
if top:
key = (top["filename"], top["name"], top.get("line", 0))
else:
key = ("<unknown>", "<no attribution>", 0)
frame_groups[key]["count"] += 1
# Group into categories
categories = defaultdict(lambda: {"count": 0, "sources": []})
for (fname, funcname, lineno), info in frame_groups.items():
cat = _categorize_source(fname, funcname)
categories[cat]["count"] += info["count"]
categories[cat]["sources"].append(
(funcname, _short_path(fname), lineno, info["count"])
)
results[target_sz] = {
"total_count": total_count,
"total_churn_gb": total_count * target_sz / 1e9,
"categories": dict(categories),
}
return results
def print_allocation_churn(results, label=""):
if not results:
return
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
else:
print("\n ALLOCATION CHURN BY SOURCE")
for target_sz in sorted(
results, key=lambda s: results[s]["total_churn_gb"], reverse=True
):
info = results[target_sz]
mb = target_sz / 1e6
print(
f"\n {mb:.1f}MB × {info['total_count']:,} = {info['total_churn_gb']:.0f}GB churn:"
)
for cat, cinfo in sorted(
info["categories"].items(), key=lambda x: x[1]["count"], reverse=True
):
pct = cinfo["count"] / info["total_count"] * 100
print(f" {pct:>4.0f}% {cat} ({cinfo['count']} allocs)")
for funcname, short, lineno, cnt in sorted(
cinfo["sources"], key=lambda x: x[3], reverse=True
)[:3]:
print(f" {funcname:35s} {short}:{lineno} (x{cnt})")
# ---- CPU overhead analysis ------------------------------------------------
def analyze_cpu_overhead(events):
"""Analyze memcpy, checkpoint recomputation, and GPU utilization from trace events."""
cuda_total_us = 0
cuda_count = 0
memcpy_stats = defaultdict(lambda: {"dur_us": 0, "count": 0, "bytes": 0})
checkpoint_us = 0
checkpoint_count = 0
min_ts = float("inf")
max_ts_end = 0
for ev in events:
if ev.get("ph") != "X":
continue
cat = ev.get("cat", "")
name = ev.get("name", "")
dur = float(ev.get("dur", 0))
ts = float(ev.get("ts", 0))
end = ts + dur
if ts < min_ts:
min_ts = ts
if end > max_ts_end:
max_ts_end = end
if cat == "kernel":
cuda_total_us += dur
cuda_count += 1
elif cat == "gpu_memcpy":
if "DtoH" in name:
direction = "GPU→CPU (offload)"
elif "HtoD" in name:
direction = "CPU→GPU (reload)"
elif "DtoD" in name:
direction = "GPU→GPU"
else:
direction = name
memcpy_stats[direction]["dur_us"] += dur
memcpy_stats[direction]["count"] += 1
nbytes = ev.get("args", {}).get("Bytes", 0)
if nbytes:
memcpy_stats[direction]["bytes"] += int(nbytes)
elif cat in ("cpu_op", "python_function"):
nl = name.lower()
if "checkpoint" in nl or "recompute" in nl:
checkpoint_us += dur
checkpoint_count += 1
wall_us = max_ts_end - min_ts if max_ts_end > min_ts else 0
return {
"wall_clock_us": wall_us,
"cuda_total_us": cuda_total_us,
"cuda_kernel_count": cuda_count,
"memcpy_stats": dict(memcpy_stats),
"checkpoint_us": checkpoint_us,
"checkpoint_count": checkpoint_count,
}
def print_cpu_overhead(result, n_steps=2, label=""):
if not result:
return
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
else:
print("\n CPU OVERHEAD ANALYSIS")
wall = result["wall_clock_us"]
cuda = result["cuda_total_us"]
count = result["cuda_kernel_count"]
gpu_util = cuda / wall * 100 if wall > 0 else 0
cpu_gap = wall - cuda
print(
f"\n Wall clock: {wall / 1e6:.2f}s (~{wall / n_steps / 1e6:.2f}s/step)"
)
print(f" CUDA kernel time: {cuda / 1e6:.2f}s ({count:,} kernels)")
print(f" GPU utilization: {gpu_util:.1f}%")
print(f" CPU overhead: {cpu_gap / 1e6:.2f}s ({100 - gpu_util:.1f}%)")
memcpy = result["memcpy_stats"]
if memcpy:
total_memcpy = sum(v["dur_us"] for v in memcpy.values())
total_bytes = sum(v["bytes"] for v in memcpy.values())
print(
f"\n Memory transfers: {total_memcpy / 1e6:.3f}s "
f"({total_bytes / 1e9:.2f}GB, {total_memcpy / wall * 100:.1f}% of wall)"
)
for direction, info in sorted(
memcpy.items(), key=lambda x: x[1]["dur_us"], reverse=True
):
gb = info["bytes"] / 1e9
print(
f" {direction:30s} {info['dur_us'] / 1e6:.3f}s "
f"x{info['count']:>5} {gb:.2f}GB"
)
if result["checkpoint_count"] > 0:
print(
f"\n Gradient checkpoint CPU ops: {result['checkpoint_us'] / 1e6:.3f}s "
f"(x{result['checkpoint_count']})"
)
# ---- Memory snapshot analysis ---------------------------------------------
def load_snapshot(path):
"""Load a PyTorch CUDA memory snapshot from a pickle file.
WARNING: This uses pickle.load() which can execute arbitrary code.
Only load snapshot files that you generated yourself from trusted
training runs. Never load snapshots from untrusted sources.
"""
snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path)
if not snap_file.exists():
return None
print(f"Loading {snap_file.name} ({snap_file.stat().st_size / 1e6:.0f} MB)...")
with open(snap_file, "rb") as f:
return pickle.load(f) # nosec B301
def _extract_python_frames(snapshot):
"""Extract Python source attribution from snapshot blocks with stacks='all'.
The snapshot structure (when stacks='all') stores frames in:
segments[i].blocks[j].history[k].frames = [(filename, lineno, name), ...]
    Returns a dict mapping (filename, function_name, lineno) -> {"bytes": int, "count": int}
"""
source_allocs = defaultdict(lambda: {"bytes": 0, "count": 0})
for seg in snapshot.get("segments", []):
for block in seg.get("blocks", []):
if block.get("state") != "active_allocated":
continue
size = block.get("size", 0)
history = block.get("history", [])
if not history:
continue
# Use the most recent allocation history entry
last_hist = history[-1]
frames = last_hist.get("frames", [])
# Find the first Python frame (skip C++ frames)
# Frames are tuples: (filename, lineno, name)
attributed = False
for frame in frames:
if not isinstance(frame, (list, tuple)) or len(frame) < 3:
continue
filename, lineno, funcname = frame[0], frame[1], frame[2]
# Skip internal torch/cuda frames to find user-level attribution
fname_str = str(filename)
if any(
skip in fname_str
for skip in [
"torch/cuda",
"torch/_C",
"torch/utils",
"cuda/memory.py",
"<unknown>",
]
):
continue
key = (fname_str, funcname, lineno)
source_allocs[key]["bytes"] += size
source_allocs[key]["count"] += 1
attributed = True
break
# If no user frame found, use first available frame
if not attributed and frames:
frame = frames[0]
if isinstance(frame, (list, tuple)) and len(frame) >= 3:
key = (str(frame[0]), str(frame[2]), frame[1])
source_allocs[key]["bytes"] += size
source_allocs[key]["count"] += 1
return dict(source_allocs)
def _extract_source_file_summary(source_allocs):
"""Aggregate per-frame allocations to per-file level."""
file_allocs = defaultdict(lambda: {"bytes": 0, "count": 0, "functions": set()})
for (filename, funcname, _lineno), info in source_allocs.items():
file_allocs[filename]["bytes"] += info["bytes"]
file_allocs[filename]["count"] += info["count"]
file_allocs[filename]["functions"].add(funcname)
return dict(file_allocs)
def analyze_snapshot(snapshot):
segments = snapshot.get("segments", [])
total_reserved = sum(s.get("total_size", 0) for s in segments)
total_allocated = sum(s.get("allocated_size", 0) for s in segments)
# Active blocks
active_blocks = []
for seg in segments:
for block in seg.get("blocks", []):
if block.get("state") == "active_allocated":
active_blocks.append(block.get("size", 0))
# Allocation churn from trace
trace = snapshot.get("device_traces", [[]])[0]
size_counts = defaultdict(lambda: {"count": 0, "total": 0})
for ev in trace:
if ev.get("action") == "alloc":
sz = ev.get("size", 0)
size_counts[sz]["count"] += 1
size_counts[sz]["total"] += sz
# Python frame attribution
source_allocs = _extract_python_frames(snapshot)
file_summary = _extract_source_file_summary(source_allocs)
return {
"total_reserved": total_reserved,
"total_allocated": total_allocated,
"fragmentation_pct": (total_reserved - total_allocated) / total_reserved * 100
if total_reserved > 0
else 0,
"n_segments": len(segments),
"n_active_blocks": len(active_blocks),
"active_bytes": sum(active_blocks),
"largest_active": sorted(active_blocks, reverse=True)[:10],
"alloc_churn": dict(
sorted(size_counts.items(), key=lambda x: x[1]["total"], reverse=True)[:15]
),
"n_trace_events": len(trace),
"source_allocs": source_allocs,
"file_summary": file_summary,
}
def print_memory_analysis(result, label=""):
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
reserved = result["total_reserved"]
allocated = result["total_allocated"]
print(f"\n Reserved: {reserved / 1e9:.2f} GB")
print(f" Allocated: {allocated / 1e9:.2f} GB")
print(f" Utilization: {allocated / reserved * 100:.1f}%" if reserved > 0 else "")
print(f" Fragmentation: {result['fragmentation_pct']:.1f}%")
print(
f" Segments: {result['n_segments']}, Active blocks: {result['n_active_blocks']}"
)
print("\n Largest active allocations:")
for sz in result["largest_active"]:
print(f" {sz / 1e6:>10.1f} MB")
# Python source file attribution
file_summary = result.get("file_summary", {})
if file_summary:
print("\n Top allocations by source file:")
print(f" {'Source file':<55} {'Alloc':>10} {'Count':>7}")
print(f" {'-' * 75}")
for fname, info in sorted(
file_summary.items(), key=lambda x: x[1]["bytes"], reverse=True
)[:15]:
# Shorten path for display
short = fname
if len(short) > 55:
parts = short.split("/")
# Keep last 3 path components
short = ".../" + "/".join(parts[-3:])
if len(short) > 55:
short = short[:52] + "..."
funcs = ", ".join(sorted(info["functions"])[:3])
sz = info["bytes"]
unit = "MB"
val = sz / 1e6
if val >= 1000:
unit = "GB"
val = sz / 1e9
print(f" {short:<55} {val:>8.1f}{unit} {info['count']:>7}")
if funcs:
print(f" functions: {funcs[:70]}")
# Top allocations by function (more granular)
source_allocs = result.get("source_allocs", {})
if source_allocs:
print("\n Top allocations by function (with line numbers):")
print(f" {'Function':<35} {'File:Line':<35} {'Size':>10}")
print(f" {'-' * 82}")
for (fname, funcname, lineno), info in sorted(
source_allocs.items(), key=lambda x: x[1]["bytes"], reverse=True
)[:15]:
# Shorten filename
short_file = fname
parts = short_file.split("/")
if len(parts) > 2:
short_file = "/".join(parts[-2:])
loc = f"{short_file}:{lineno}"
if len(loc) > 35:
loc = "..." + loc[-32:]
sz = info["bytes"]
if sz >= 1e9:
sz_str = f"{sz / 1e9:.2f}GB"
else:
sz_str = f"{sz / 1e6:.1f}MB"
print(f" {funcname:<35} {loc:<35} {sz_str:>10}")
print("\n Top allocation churn (alloc count x size):")
print(f" {'Size':>12} {'Count':>8} {'Total churned':>14}")
print(f" {'-' * 38}")
for sz, info in result["alloc_churn"].items():
if sz >= 1e6:
print(
f" {sz / 1e6:>10.1f}MB {info['count']:>8} "
f"{info['total'] / 1e9:>12.2f}GB"
)
# ---- Peak memory timeline from trace events --------------------------------
def analyze_peak_memory(snapshot):
"""Walk through device_traces chronologically to find peak concurrent memory usage.
The snapshot's segment data only captures end-of-step state. The device_traces
record every alloc/free, letting us reconstruct peak usage and identify which
allocation sources were live at that moment.
"""
traces = snapshot.get("device_traces", [[]])[0]
if not traces:
return None
current = 0
peak = 0
peak_idx = 0
live_allocs = {} # addr -> (size, frames)
peak_live = {}
for i, ev in enumerate(traces):
action = ev.get("action")
addr = ev.get("addr", 0)
size = ev.get("size", 0)
if action == "alloc":
current += size
live_allocs[addr] = (size, ev.get("frames", []))
if current > peak:
peak = current
peak_idx = i
peak_live = dict(live_allocs)
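        # Decrement once, on "free_requested"; snapshots also emit a paired
        # "free_completed" for the same addr, which would double-count.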
elif action == "free_requested":
if addr in live_allocs:
current -= live_allocs[addr][0]
del live_allocs[addr]
# Categorize allocations at peak
peak_categories = defaultdict(lambda: {"bytes": 0, "count": 0})
for _addr, (size, frames) in peak_live.items():
top = _get_top_python_frame(frames)
if top:
cat = _categorize_source(top["filename"], top["name"])
else:
cat = "Unknown"
peak_categories[cat]["bytes"] += size
peak_categories[cat]["count"] += 1
return {
"peak_bytes": peak,
"peak_event_idx": peak_idx,
"total_events": len(traces),
"end_bytes": current,
"peak_categories": dict(peak_categories),
}
def print_peak_memory(result, mem_result=None, label=""):
if not result:
return
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
peak_gb = result["peak_bytes"] / 1e9
end_gb = result["end_bytes"] / 1e9
# The device_traces only record allocations AFTER profiling starts.
# Model weights and other persistent allocations are not tracked.
    # We estimate the persistent baseline as snapshot allocated minus the
    # end-of-trace traced bytes.
persistent_gb = 0
if mem_result:
persistent_gb = mem_result["total_allocated"] / 1e9 - end_gb
total_peak_gb = persistent_gb + peak_gb
print(
f"\n Profiled peak (transient): {peak_gb:.2f} GB "
f"(at event {result['peak_event_idx']:,} / {result['total_events']:,})"
)
if persistent_gb > 0:
print(
f" Persistent baseline: {persistent_gb:.2f} GB "
f"(model + optimizer, allocated before profiling)"
)
print(f" Estimated total peak: {total_peak_gb:.2f} GB")
print(f" Transient headroom: {peak_gb - end_gb:.2f} GB above end-of-trace")
cats = result.get("peak_categories", {})
if cats:
print("\n Allocations live at peak:")
print(f" {'Category':<35} {'Size':>10} {'Count':>7}")
print(f" {'-' * 55}")
for cat, info in sorted(
cats.items(), key=lambda x: x[1]["bytes"], reverse=True
):
sz = info["bytes"]
if sz >= 1e9:
sz_str = f"{sz / 1e9:.2f} GB"
else:
sz_str = f"{sz / 1e6:.1f} MB"
print(f" {cat:<35} {sz_str:>10} {info['count']:>7}")
# ---- Fragmentation diagnosis -----------------------------------------------
def analyze_fragmentation(snapshot):
"""Analyze segment-level memory layout to explain fragmentation.
Examines each CUDA segment for inactive (freed but unreturned) blocks,
pinned small allocations that prevent segment merging, and the overall
segment size distribution.
"""
segments = snapshot.get("segments", [])
if not segments:
return None
total_reserved = 0
total_allocated = 0
total_inactive = 0
segment_sizes = []
inactive_gaps = [] # (gap_size, segment_size, active_around)
pinned_fragments = [] # small active blocks surrounded by inactive
for seg in segments:
seg_size = seg.get("total_size", 0)
total_reserved += seg_size
segment_sizes.append(seg_size)
blocks = seg.get("blocks", [])
seg_active = 0
seg_inactive = 0
for bi, block in enumerate(blocks):
bsize = block.get("size", 0)
if block.get("state") == "active_allocated":
seg_active += bsize
total_allocated += bsize
# Check if this small block is surrounded by inactive
if bsize < 2 * 1024 * 1024: # < 2MB
prev_inactive = bi > 0 and blocks[bi - 1].get("state") == "inactive"
next_inactive = (
bi < len(blocks) - 1
and blocks[bi + 1].get("state") == "inactive"
)
if prev_inactive and next_inactive:
pinned_fragments.append((bsize, seg_size))
elif block.get("state") == "inactive":
seg_inactive += bsize
total_inactive += bsize
inactive_gaps.append((bsize, seg_size))
# Classify segment sizes
size_buckets = defaultdict(lambda: {"count": 0, "total": 0})
for sz in segment_sizes:
if sz >= 1024 * 1024 * 1024:
bucket = ">=1 GB"
elif sz >= 256 * 1024 * 1024:
bucket = "256MB-1GB"
elif sz >= 64 * 1024 * 1024:
bucket = "64-256MB"
elif sz >= 2 * 1024 * 1024:
bucket = "2-64MB"
else:
bucket = "<2MB"
size_buckets[bucket]["count"] += 1
size_buckets[bucket]["total"] += sz
# Large inactive gaps that could be reclaimed
inactive_gaps.sort(key=lambda x: x[0], reverse=True)
return {
"total_reserved": total_reserved,
"total_allocated": total_allocated,
"total_inactive": total_inactive,
"n_segments": len(segments),
"segment_size_buckets": dict(size_buckets),
"large_inactive_gaps": inactive_gaps[:20],
"pinned_fragments": len(pinned_fragments),
"expandable_segments_would_help": (
total_inactive > 0.1 * total_reserved and len(segments) > 10
),
}
def print_fragmentation(result, gpu_capacity_gb=None, label=""):
if not result:
return
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
reserved = result["total_reserved"]
allocated = result["total_allocated"]
inactive = result["total_inactive"]
frag_pct = inactive / reserved * 100 if reserved > 0 else 0
print(
f"\n Reserved: {reserved / 1e9:.2f} GB across {result['n_segments']} segments"
)
print(f" Allocated: {allocated / 1e9:.2f} GB")
print(f" Inactive: {inactive / 1e9:.2f} GB ({frag_pct:.1f}% fragmentation)")
if result["pinned_fragments"] > 0:
print(
f" Pinned small blocks (<2MB between inactive): "
f"{result['pinned_fragments']} (prevent segment merging)"
)
# Segment size distribution
print("\n Segment size distribution:")
bucket_order = [">=1 GB", "256MB-1GB", "64-256MB", "2-64MB", "<2MB"]
for bucket in bucket_order:
info = result["segment_size_buckets"].get(bucket)
if info:
print(
f" {bucket:<12} {info['count']:>4} segments "
f"{info['total'] / 1e9:>6.2f} GB"
)
# Largest inactive gaps
gaps = result.get("large_inactive_gaps", [])
if gaps:
print("\n Largest inactive gaps (freed but unreclaimable):")
shown = 0
for gap_sz, seg_sz in gaps:
if gap_sz >= 32 * 1024 * 1024 and shown < 10:
print(
f" {gap_sz / 1e6:>8.0f} MB gap in {seg_sz / 1e6:.0f} MB segment"
)
shown += 1
# OOM risk assessment
if gpu_capacity_gb:
gpu_bytes = gpu_capacity_gb * 1e9
usable = gpu_bytes - (reserved - allocated) # capacity minus fragmented waste
print(f"\n OOM Risk Assessment (GPU: {gpu_capacity_gb:.1f} GB):")
print(
f" Usable capacity: {usable / 1e9:.2f} GB "
f"(GPU capacity minus {inactive / 1e9:.2f} GB fragmentation)"
)
headroom = gpu_bytes - reserved
print(f" Current headroom: {headroom / 1e9:.2f} GB")
if headroom < 1.0e9:
print(" ⚠ CRITICAL: <1 GB headroom — high OOM risk!")
elif headroom < 2.0e9:
print(" ⚠ WARNING: <2 GB headroom — moderate OOM risk")
# Recommendation
if result.get("expandable_segments_would_help"):
print("\n → FIX: Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
print(" This eliminates segment fragmentation by growing segments in-place,")
print(
f" which would reclaim up to {inactive / 1e9:.1f} GB of wasted memory."
)
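# The suggested fix is an environment variable set before the training process
# starts, e.g. (shell, illustrative invocation):
#   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True axolotl train config.yml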
# ---- Sequence-length scaling analysis --------------------------------------
def analyze_scaling(_mem_a, _mem_b, churn_a, churn_b):
"""Compare per-tensor allocation sizes between two runs.
When two profiles differ only in sequence length, this shows which tensor
categories scale with sequence length by comparing the dominant tensor sizes.
Total churn may differ due to different profiling windows, so we focus on
per-tensor size ratios instead.
"""
if not churn_a or not churn_b:
return None
def _cat_sizes(churn):
"""Map category -> {largest_size, total_bytes, count}."""
cat_data = defaultdict(lambda: {"max_size": 0, "sizes": [], "count": 0})
for sz, info in churn.items():
for cat, cinfo in info.get("categories", {}).items():
cat_data[cat]["count"] += cinfo["count"]
cat_data[cat]["sizes"].append((sz, cinfo["count"]))
if sz > cat_data[cat]["max_size"]:
cat_data[cat]["max_size"] = sz
return dict(cat_data)
cats_a = _cat_sizes(churn_a)
cats_b = _cat_sizes(churn_b)
all_cats = sorted(
set(list(cats_a) + list(cats_b)),
key=lambda c: max(
cats_a.get(c, {"max_size": 0})["max_size"],
cats_b.get(c, {"max_size": 0})["max_size"],
),
reverse=True,
)
scaling = []
for cat in all_cats:
a = cats_a.get(cat)
b = cats_b.get(cat)
if not a or not b:
continue
a_max = a["max_size"]
b_max = b["max_size"]
if a_max > 1e6 and b_max > 1e6: # Only compare >1MB tensors
tensor_ratio = b_max / a_max if a_max > 0 else None
scaling.append(
{
"category": cat,
"size_a_mb": a_max / 1e6,
"size_b_mb": b_max / 1e6,
"tensor_ratio": tensor_ratio,
"count_a": a["count"],
"count_b": b["count"],
"scales_with_seqlen": tensor_ratio is not None
and tensor_ratio > 1.05,
}
)
scaling.sort(key=lambda x: x["size_b_mb"], reverse=True)
return scaling
def print_scaling(scaling, label_a="Before", label_b="After", label=""):
if not scaling:
return
if label:
print(f"\n{'=' * 75}")
print(f" {label}")
print(f"{'=' * 75}")
print("\n Per-tensor size comparison (largest tensor per category):")
print(
f" {'Category':<35} {'A size':>10} {'B size':>10} {'Ratio':>7} {'Scales?':>8}"
)
print(f" {'-' * 73}")
for entry in scaling:
ratio_str = f"{entry['tensor_ratio']:.2f}x" if entry["tensor_ratio"] else "N/A"
scales = "YES" if entry["scales_with_seqlen"] else "no"
print(
f" {entry['category']:<35} {entry['size_a_mb']:>8.1f}MB "
f"{entry['size_b_mb']:>8.1f}MB {ratio_str:>7} {scales:>8}"
)
# Summary
seq_scaling = [e for e in scaling if e["scales_with_seqlen"]]
constant = [e for e in scaling if not e["scales_with_seqlen"]]
if seq_scaling:
ratios = [e["tensor_ratio"] for e in seq_scaling if e["tensor_ratio"]]
avg_ratio = sum(ratios) / len(ratios) if ratios else 0
print(f"\n Sequence-length scaling detected ({avg_ratio:.2f}x avg):")
for e in seq_scaling:
print(
f" - {e['category']}: {e['size_a_mb']:.1f}MB -> "
f"{e['size_b_mb']:.1f}MB ({e['tensor_ratio']:.2f}x)"
)
if constant:
print("\n Constant-size categories (do not scale with seq len):")
for e in constant:
print(f" - {e['category']}: {e['size_a_mb']:.1f}MB")
# ---- Main -----------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(
description="Analyze axolotl training profiler output",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"path",
help="Path to output directory (containing profiler_trace.json and/or "
"snapshot.pickle) or directly to a trace file. "
"Security note: snapshot.pickle uses pickle deserialization — "
"only use files from your own trusted training runs.",
)
parser.add_argument(
"--compare",
help="Path to second run for A/B comparison. "
"Same security note as path: only use trusted snapshot files.",
)
parser.add_argument(
"--include-warmup",
action="store_true",
help="Include step 0 (warmup/compilation) in timing analysis",
)
parser.add_argument(
"--memory-only",
action="store_true",
help="Only analyze memory snapshot, skip trace",
)
parser.add_argument(
"--quick",
action="store_true",
help="Only load first 2M events for rapid analysis of large traces",
)
parser.add_argument(
"--gpu-gb",
type=float,
default=None,
help="GPU total memory in GB (for OOM risk assessment). "
"Auto-detected if not specified.",
)
args = parser.parse_args()
# Auto-detect GPU capacity if not specified
gpu_capacity_gb = args.gpu_gb
if gpu_capacity_gb is None:
try:
import torch
if torch.cuda.is_available():
gpu_capacity_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
except Exception:
pass
skip = not args.include_warmup
trace_result = None
mem_result = None
events = None
# -- Trace analysis --
if not args.memory_only:
events = load_trace(args.path, quick=args.quick)
if events:
trace_result = analyze_trace(events, skip_warmup=skip)
if trace_result:
print_trace_analysis(trace_result, label=f"Trace: {args.path}")
if args.compare:
events2 = load_trace(args.compare, quick=args.quick)
if events2:
result2 = analyze_trace(events2, skip_warmup=skip)
if result2:
print_trace_analysis(
result2, label=f"Trace: {args.compare}"
)
compare_traces(trace_result, result2)
else:
print(f" No profiler_trace.json found in {args.path}")
# -- CPU overhead analysis (from trace) --
if events and trace_result:
cpu_result = analyze_cpu_overhead(events)
print_cpu_overhead(
cpu_result,
n_steps=trace_result["n_steps"],
label=f"CPU Overhead: {args.path}",
)
# -- Memory analysis --
snapshot = load_snapshot(args.path)
churn_result = None
churn2 = None
if snapshot:
mem_result = analyze_snapshot(snapshot)
print_memory_analysis(mem_result, label=f"Memory: {args.path}")
# Peak memory timeline
peak_result = analyze_peak_memory(snapshot)
print_peak_memory(
peak_result, mem_result=mem_result, label=f"Peak Memory: {args.path}"
)
# Fragmentation diagnosis
frag_result = analyze_fragmentation(snapshot)
print_fragmentation(
frag_result,
gpu_capacity_gb=gpu_capacity_gb,
label=f"Fragmentation: {args.path}",
)
# Allocation churn attribution
churn_result = analyze_allocation_churn(snapshot)
if churn_result:
print_allocation_churn(churn_result, label=f"Allocation Churn: {args.path}")
if args.compare:
snapshot2 = load_snapshot(args.compare)
if snapshot2:
mem2 = analyze_snapshot(snapshot2)
print_memory_analysis(mem2, label=f"Memory: {args.compare}")
# Peak memory for comparison
peak2 = analyze_peak_memory(snapshot2)
print_peak_memory(
peak2, mem_result=mem2, label=f"Peak Memory: {args.compare}"
)
# Fragmentation for comparison
frag2 = analyze_fragmentation(snapshot2)
print_fragmentation(
frag2,
gpu_capacity_gb=gpu_capacity_gb,
label=f"Fragmentation: {args.compare}",
)
churn2 = analyze_allocation_churn(snapshot2)
if churn2:
print_allocation_churn(
churn2, label=f"Allocation Churn: {args.compare}"
)
# Memory comparison summary
print("\n Memory comparison:")
print(
f" Reserved: {mem_result['total_reserved'] / 1e9:.2f} -> "
f"{mem2['total_reserved'] / 1e9:.2f} GB"
)
print(
f" Allocated: {mem_result['total_allocated'] / 1e9:.2f} -> "
f"{mem2['total_allocated'] / 1e9:.2f} GB"
)
print(
f" Frag: {mem_result['fragmentation_pct']:.1f}% -> "
f"{mem2['fragmentation_pct']:.1f}%"
)
if peak_result and peak2:
print(
f" Peak: {peak_result['peak_bytes'] / 1e9:.2f} -> "
f"{peak2['peak_bytes'] / 1e9:.2f} GB"
)
# Scaling analysis
if churn_result and churn2:
scaling = analyze_scaling(mem_result, mem2, churn_result, churn2)
print_scaling(
scaling,
label_a=str(args.path),
label_b=str(args.compare),
label="Allocation Scaling Analysis",
)
    elif args.memory_only:
        print(f" No snapshot.pickle found in {args.path}")
# -- Summary --
if trace_result:
print_summary(trace_result, mem_result=mem_result)
if __name__ == "__main__":
main()