#!/usr/bin/env python3
"""
Axolotl Training Profiler Analyzer
===================================

Analyzes PyTorch profiler output from axolotl training runs (profiler_steps config).
Produces breakdowns by CUDA kernel category, identifies bottlenecks, and optionally
compares two traces (e.g. before/after optimization).

Supports both:
  - profiler_trace.json (torch.profiler Chrome trace -- timing analysis)
  - snapshot.pickle (torch.cuda.memory._snapshot -- memory analysis)

Usage:
    # Analyze a single trace
    python analyze_profile.py outputs/qwen35_moe_profile/

    # Compare before vs after
    python analyze_profile.py outputs/before/ --compare outputs/after/

    # Include step 0 (warmup/compilation) in analysis
    python analyze_profile.py outputs/run/ --include-warmup

    # Memory-only analysis
    python analyze_profile.py outputs/run/ --memory-only

    # Quick mode (first 2M events only, for large traces)
    python analyze_profile.py outputs/run/ --quick
"""

import argparse
import json
import pickle  # nosec B403
import time
from collections import defaultdict
from pathlib import Path

# ---- Kernel categorization ------------------------------------------------

# (category name, lowercase substrings). The first category with a matching
# substring wins, so more specific patterns must come before general ones.
KERNEL_CATEGORIES = [
    # ScatterMoE -- ordered most specific first
    (
        "ScatterMoE bwd LoRA (split)",
        ["_group_bwd_lora_split", "_group_bwd_da", "_group_bwd_db"],
    ),
    ("ScatterMoE bwd LoRA (fused)", ["_group_bwd_lora"]),
    ("ScatterMoE bwd dX", ["_scatter2scatter_lora_dx"]),
    ("ScatterMoE fwd", ["_scatter2scatter_lora"]),
    # Quantization
    ("BnB Dequantization", ["dequantize", "kDequantizeBlockwise"]),
    # Attention
    ("Flash Attention", ["flash", "fmha"]),
    # Loss
    ("CCE Loss", ["_cce_"]),
    # LoRA fused kernels (autograd.Function based)
    ("LoRA QKV Kernel", ["lora_qkv"]),
    ("LoRA O Kernel", ["lora_o"]),
    ("LoRA MLP Kernel", ["lora_mlp"]),
    # LoRA activation kernels (SwiGLU/GEGLU Triton kernels)
    ("LoRA Activation (SwiGLU/GEGLU)", ["swiglu", "geglu"]),
    # DoRA weight norm
    ("DoRA Weight Norm", ["linalg_norm", "dora_scale"]),
    # Compute
    ("GEMM/CUTLASS", ["cutlass", "gemm", "gemv", "cublas"]),
    ("Triton (norms etc)", ["triton"]),
    ("Conv1d", ["conv1d", "causal_conv"]),
    # Optimizer
    ("Optimizer", ["adam", "optim"]),
    # Dtype conversion (fp32→bf16 LoRA matrix casts etc)
    ("Dtype Conversion", ["_to_copy", "to_copy"]),
    # Memory
    ("Elementwise/Fill", ["fill", "elementwise", "cast", "copy_kernel"]),
    ("Memory ops", ["memcpy", "memset"]),
    # Routing
    ("TopK/Sort", ["topk", "sort"]),
    ("Index/Gather/Scatter", ["index", "gather", "scatter"]),
]

# Categories to keep when pre-filtering during streaming load
_KEEP_CATS = {"kernel", "gpu_memcpy", "cpu_op", "python_function", "ac2g", "Runtime"}


def categorize_kernel(name):
    """Return the first KERNEL_CATEGORIES name whose pattern occurs in *name*.

    Matching is a substring test against the lower-cased kernel name; kernels
    that match nothing are reported as "Other".
    """
    nl = name.lower()
    for cat_name, patterns in KERNEL_CATEGORIES:
        if any(p in nl for p in patterns):
            return cat_name
    return "Other"


def _pct(part, whole):
    """Return *part* as a percentage of *whole*, or 0.0 when *whole* is zero."""
    return part / whole * 100 if whole else 0.0


# ---- Trace loading ---------------------------------------------------------


def _try_ijson_load(trace_file, quick=False, max_events=2_000_000):
    """Stream-parse trace JSON with ijson.

    Returns:
        ``(events, count)`` where *events* holds only the events whose ``cat``
        is in ``_KEEP_CATS`` (plus ``ph == "M"`` metadata events) and *count*
        is the number of events read, or ``None`` when ijson is not installed.
    """
    try:
        import ijson
    except ImportError:
        return None

    events = []
    count = 0
    with open(trace_file, "rb") as f:
        for ev in ijson.items(f, "traceEvents.item"):
            count += 1
            cat = ev.get("cat", "")
            # Pre-filter: only keep categories we care about
            if cat in _KEEP_CATS or ev.get("ph") == "M":
                events.append(ev)
            if quick and count >= max_events:
                print(f" --quick: stopped after {count:,} events")
                break
    return events, count


def load_trace(path, quick=False):
    """Load and pre-filter the events of a profiler trace.

    *path* may be the trace file itself or a directory containing
    ``profiler_trace.json``. Files larger than 0.5 GB are streamed with ijson
    when it is available; otherwise the whole file is read with ``json.load``.

    Returns:
        The filtered event list, or ``None`` when the trace file does not exist.
    """
    trace_file = (
        Path(path) / "profiler_trace.json" if Path(path).is_dir() else Path(path)
    )
    if not trace_file.exists():
        return None

    size_gb = trace_file.stat().st_size / 1e9
    print(f"Loading {trace_file.name} ({size_gb:.1f} GB)...")
    t0 = time.monotonic()

    # Try streaming parser first for large files
    if size_gb > 0.5:
        result = _try_ijson_load(trace_file, quick=quick)
        if result is not None:
            events, total_count = result
            elapsed = time.monotonic() - t0
            print(
                f" {total_count:,} total events, {len(events):,} kept "
                f"(streamed in {elapsed:.1f}s)"
            )
            return events

    # Fallback: standard json.load (JSON text is UTF-8, so do not rely on the
    # platform's default encoding)
    with open(trace_file, encoding="utf-8") as f:
        data = json.load(f)
    all_events = data.get("traceEvents", [])
    elapsed = time.monotonic() - t0

    if quick:
        all_events = all_events[:2_000_000]
        print(f" --quick: limited to first {len(all_events):,} events")

    # Pre-filter
    events = [
        ev
        for ev in all_events
        if ev.get("cat", "") in _KEEP_CATS or ev.get("ph") == "M"
    ]
    print(
        f" {len(all_events):,} total events, {len(events):,} kept "
        f"(loaded in {elapsed:.1f}s)"
    )
    return events


# ---- Trace analysis -------------------------------------------------------


def _estimate_n_steps(cuda_events):
    """Estimate the number of training steps from CUDA event timestamps.

    A step boundary is a gap between consecutive kernel start timestamps that
    exceeds ``max(50 * median_gap, 100_000)`` (timestamps are treated as
    microseconds, so the floor is 100ms). Fewer than 100 events always yields 1.
    """
    if len(cuda_events) < 100:
        return 1

    timestamps = sorted(float(ev.get("ts", 0)) for ev in cuda_events)
    # Gaps between consecutive events
    gaps = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    if not gaps:
        return 1

    median_gap = sorted(gaps)[len(gaps) // 2]
    # A step boundary has a gap much larger than the median inter-kernel gap
    threshold = max(median_gap * 50, 100_000)  # at least 100ms
    n_boundaries = sum(1 for g in gaps if g > threshold)
    return max(n_boundaries + 1, 1)


def analyze_trace(events, skip_warmup=True):
    """Aggregate CUDA kernel timings from trace events.

    Args:
        events: trace event dicts (as returned by ``load_trace``).
        skip_warmup: when true and more than 1000 CUDA events are present,
            drop everything in the first 45% of the CUDA timestamp span.

    Returns:
        ``None`` when there are no kernel/gpu_memcpy events, otherwise a dict
        with ``total_cuda_us``, ``n_steps``, ``categories``, ``kernel_stats``,
        ``n_events``, ``fill_by_size`` and ``wall_clock_us``.
    """
    cuda_events = [
        ev
        for ev in events
        if ev.get("ph") == "X" and ev.get("cat") in ("kernel", "gpu_memcpy")
    ]
    if not cuda_events:
        print(" No CUDA kernel events found!")
        return None

    cutoff_ts = None
    if skip_warmup and len(cuda_events) > 1000:
        # Only the extremes are needed, so avoid sorting every timestamp.
        all_ts = [float(ev.get("ts", 0)) for ev in cuda_events]
        min_ts, max_ts = min(all_ts), max(all_ts)
        total_span = max_ts - min_ts
        # Step 0 is warmup (Triton compilation + autotune). It's typically
        # the slowest step by far. Use 45% of wall-clock as cutoff -- step 0
        # usually takes >50% of total time when it includes compilation.
        cutoff_ts = min_ts + total_span * 0.45
        before = len(cuda_events)
        cuda_events = [ev for ev in cuda_events if float(ev.get("ts", 0)) > cutoff_ts]
        print(f" Excluding step 0 (warmup): {before:,} -> {len(cuda_events):,} events")

    n_steps_profiled = _estimate_n_steps(cuda_events)

    # Aggregate by kernel (cast to float to handle ijson Decimal values)
    kernel_stats = defaultdict(lambda: {"total_us": 0.0, "count": 0, "max_us": 0.0})
    for ev in cuda_events:
        name = ev.get("name", "unknown")
        dur = float(ev.get("dur", 0))
        stats = kernel_stats[name]
        stats["total_us"] += dur
        stats["count"] += 1
        stats["max_us"] = max(stats["max_us"], dur)

    total_cuda = sum(v["total_us"] for v in kernel_stats.values())

    # Group by category
    cat_stats = defaultdict(lambda: {"total_us": 0, "count": 0})
    for name, info in kernel_stats.items():
        cat = categorize_kernel(name)
        cat_stats[cat]["total_us"] += info["total_us"]
        cat_stats[cat]["count"] += info["count"]

    # Fill/zero_ analysis: find FillFunctor kernels and group by tensor size
    fill_by_size = defaultdict(lambda: {"total_us": 0, "count": 0})
    for ev in cuda_events:
        name = ev.get("name", "")
        if "FillFunctor" not in name and "fill" not in name.lower():
            continue
        # Extract Input Dims from args if present
        args = ev.get("args", {})
        input_dims = args.get("Input Dims", args.get("input_dims", "unknown"))
        if isinstance(input_dims, list):
            input_dims = str(input_dims)  # lists are unhashable; key by repr
        dur = float(ev.get("dur", 0))
        fill_by_size[input_dims]["total_us"] += dur
        fill_by_size[input_dims]["count"] += 1

    # CPU op analysis for wall-clock estimation (apply same warmup cutoff)
    cpu_ops = [
        ev
        for ev in events
        if ev.get("ph") == "X" and ev.get("cat") in ("cpu_op", "python_function")
    ]
    if cutoff_ts is not None:
        cpu_ops = [ev for ev in cpu_ops if float(ev.get("ts", 0)) > cutoff_ts]

    wall_clock_us = 0
    if cpu_ops:
        min_cpu_ts = min(float(e.get("ts", 0)) for e in cpu_ops)
        max_cpu_end = max(
            float(e.get("ts", 0)) + float(e.get("dur", 0)) for e in cpu_ops
        )
        wall_clock_us = max_cpu_end - min_cpu_ts

    return {
        "total_cuda_us": total_cuda,
        "n_steps": n_steps_profiled,
        "categories": dict(cat_stats),
        "kernel_stats": dict(kernel_stats),
        "n_events": len(cuda_events),
        "fill_by_size": dict(fill_by_size),
        "wall_clock_us": wall_clock_us,
    }


def print_trace_analysis(result, label=""):
    """Print the category, top-kernel and fill breakdown of an analyze_trace result.

    Percentages are reported as 0 when the total CUDA time is zero.
    """
    total = result["total_cuda_us"]
    n = result["n_steps"]

    if label:
        print(f"\n{'=' * 75}")
        print(f" {label}")
        print(f"{'=' * 75}")

    print(
        f"\n CUDA kernel time: {total / 1e6:.2f}s over {n} steps "
        f"(~{total / n / 1e6:.2f}s/step)"
    )
    if result.get("wall_clock_us"):
        wc = result["wall_clock_us"]
        print(
            f" Wall clock span: {wc / 1e6:.2f}s over {n} steps "
            f"(~{wc / n / 1e6:.2f}s/step)"
        )

    print(f"\n {'Category':<40} {'Total':>9} {'%':>6} {'Count':>7} {'Per step':>9}")
    print(f" {'-' * 75}")
    for cat, info in sorted(
        result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True
    ):
        pct = _pct(info["total_us"], total)
        ps = info["total_us"] / n / 1000
        print(
            f" {cat:<40} {info['total_us'] / 1000:>8.1f}ms {pct:>5.1f}% "
            f"{info['count']:>7} {ps:>7.1f}ms"
        )

    print("\n Top 15 individual kernels:")
    for name, info in sorted(
        result["kernel_stats"].items(), key=lambda x: x[1]["total_us"], reverse=True
    )[:15]:
        pct = _pct(info["total_us"], total)
        avg = info["total_us"] / info["count"] / 1000
        print(
            f" {name[:62]:<62} {info['total_us'] / 1000:>7.1f}ms "
            f"({pct:>4.1f}%) x{info['count']:<5} avg={avg:.3f}ms"
        )

    # Fill/zero_ breakdown
    fill_data = result.get("fill_by_size", {})
    if fill_data:
        total_fill_us = sum(v["total_us"] for v in fill_data.values())
        if total_fill_us > 0:
            print(
                f"\n Fill/zero_ kernel breakdown by tensor size "
                f"(total: {total_fill_us / 1000:.1f}ms, "
                f"{_pct(total_fill_us, total):.1f}% of CUDA time):"
            )
            print(f" {'Input Dims':<50} {'Time':>9} {'Count':>7}")
            print(f" {'-' * 70}")
            for dims, info in sorted(
                fill_data.items(), key=lambda x: x[1]["total_us"], reverse=True
            )[:10]:
                dims_str = str(dims)[:50]
                print(
                    f" {dims_str:<50} {info['total_us'] / 1000:>7.1f}ms "
                    f"{info['count']:>7}"
                )


def print_summary(result, mem_result=None):
    """Print optimization recommendations and summary.

    Args:
        result: dict produced by ``analyze_trace``.
        mem_result: optional dict with ``total_reserved``, ``total_allocated``
            and ``fragmentation_pct`` keys.
    """
    total = result["total_cuda_us"]
    n = result["n_steps"]

    print(f"\n{'=' * 75}")
    print(" SUMMARY & RECOMMENDATIONS")
    print(f"{'=' * 75}")

    # Estimated per-step wall clock
    wc = result.get("wall_clock_us", 0)
    if wc > 0:
        print(f"\n Estimated per-step wall clock: {wc / n / 1e6:.2f}s")
    else:
        print(f"\n Estimated per-step CUDA time: {total / n / 1e6:.2f}s")

    # Memory utilization
    if mem_result:
        reserved = mem_result["total_reserved"]
        allocated = mem_result["total_allocated"]
        if reserved > 0:
            util_pct = allocated / reserved * 100
            print(
                f" Memory utilization: {util_pct:.1f}% "
                f"({allocated / 1e9:.2f} / {reserved / 1e9:.2f} GB)"
            )

    # Build recommendations
    recommendations = []
    cats_sorted = sorted(
        result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True
    )

    # Check top category
    if cats_sorted:
        top_cat, top_info = cats_sorted[0]
        top_pct = _pct(top_info["total_us"], total)
        if top_pct > 30:
            if "GEMM" in top_cat or "CUTLASS" in top_cat:
                recommendations.append(
                    f"GEMM/matmul dominates ({top_pct:.0f}%). "
                    f"Consider FP8 training, LoRA (fewer params), "
                    f"or smaller batch size to reduce compute."
                )
            elif "Attention" in top_cat:
                recommendations.append(
                    f"Attention dominates ({top_pct:.0f}%). "
                    f"Ensure FlashAttention v2/v3 is active. "
                    f"Consider reducing sequence length or using sliding window."
                )
            elif "Fill" in top_cat or "Elementwise" in top_cat:
                recommendations.append(
                    f"Elementwise/Fill ops dominate ({top_pct:.0f}%). "
                    f"Enable kernel fusion (Liger, torch.compile) to reduce "
                    f"memory-bound elementwise operations."
                )
            elif "ScatterMoE" in top_cat:
                recommendations.append(
                    f"MoE routing/scatter dominates ({top_pct:.0f}%). "
                    f"Check expert count and capacity factor. "
                    f"Verify ScatterMoE kernels are using optimal block sizes."
                )
            else:
                recommendations.append(
                    f"'{top_cat}' dominates ({top_pct:.0f}%). "
                    f"Focus optimization efforts here first."
                )

    # Check memory ops
    for cat, info in cats_sorted:
        pct = _pct(info["total_us"], total)
        if "Memory" in cat and pct > 10:
            recommendations.append(
                f"Memory ops are {pct:.0f}% of CUDA time. "
                f"Consider gradient checkpointing, reducing activation "
                f"recomputation, or pinned memory for data loading."
            )
            break

    # Check fill overhead
    fill_data = result.get("fill_by_size", {})
    total_fill_us = sum(v["total_us"] for v in fill_data.values())
    fill_pct = _pct(total_fill_us, total)
    if fill_pct > 5:
        recommendations.append(
            f"Fill/zero_ kernels consume {fill_pct:.1f}% of CUDA time. "
            f"Large zero-fills suggest excessive tensor allocation. "
            f"Consider reusing buffers or lazy initialization."
        )

    # Check fragmentation
    if mem_result and mem_result.get("fragmentation_pct", 0) > 20:
        frag = mem_result["fragmentation_pct"]
        recommendations.append(
            f"Memory fragmentation is {frag:.0f}%. "
            f"Use PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True "
            f"or max_split_size_mb to reduce fragmentation."
        )

    # GPU utilization hint from wall clock vs CUDA time
    if wc > 0 and total > 0:
        gpu_util = total / wc * 100
        if gpu_util < 50:
            recommendations.append(
                f"GPU utilization is ~{gpu_util:.0f}% (CUDA time vs wall clock). "
                f"CPU-side bottleneck likely. Profile data loading, "
                f"reward computation, or weight sync overhead."
            )

    # Print top 3
    print("\n Top optimization recommendations:")
    if not recommendations:
        print(" No major issues detected.")
    for i, rec in enumerate(recommendations[:3], 1):
        print(f" {i}. {rec}")
    print()


def compare_traces(before, after):
    """Print a per-step, per-category comparison of two analyze_trace results.

    Categories below 0.3% of per-step CUDA time in both runs are omitted. The
    relative change is shown as "n/a" when the "before" run has zero CUDA time.
    """
    print(f"\n{'=' * 75}")
    print(" COMPARISON")
    print(f"{'=' * 75}")

    tb = before["total_cuda_us"]
    ta = after["total_cuda_us"]
    nb = before["n_steps"]
    na = after["n_steps"]
    step_b = tb / nb
    step_a = ta / na

    if step_b > 0:
        change = f"{(step_a - step_b) / step_b * 100:+.1f}%"
    else:
        change = "n/a"  # no baseline to compare against
    print(
        f"\n Per-step CUDA time: {step_b / 1e6:.2f}s -> {step_a / 1e6:.2f}s "
        f"({change})"
    )

    all_cats = sorted(
        set(list(before["categories"]) + list(after["categories"])),
        key=lambda c: before["categories"].get(c, {"total_us": 0})["total_us"],
        reverse=True,
    )

    print(
        f"\n {'Category':<35} {'Before/step':>11} {'After/step':>11} "
        f"{'Delta':>9} {'Speedup':>8}"
    )
    print(f" {'-' * 78}")
    for cat in all_cats:
        b = before["categories"].get(cat, {"total_us": 0})["total_us"] / nb
        a = after["categories"].get(cat, {"total_us": 0})["total_us"] / na
        delta = a - b
        speedup = b / a if a > 0 else float("inf")
        if b > step_b * 0.003 or a > step_a * 0.003:
            print(
                f" {cat:<35} {b / 1000:>9.1f}ms {a / 1000:>9.1f}ms "
                f"{delta / 1000:>+8.1f}ms {speedup:>7.2f}x"
            )


# ---- Allocation churn attribution -----------------------------------------


def _short_path(p):
    """Shorten a file path for display."""
    if "/site-packages/" in p:
        return "..." + p.split("/site-packages/")[1]
    if "/axolotl/src/" in p:
        return "axolotl/" + p.split("/axolotl/src/")[1]
    if "/axolotl/" in p:
        return "axolotl/" + p.split("/axolotl/")[1]
    if len(p) > 60:
        parts = p.split("/")
        return ".../" + "/".join(parts[-3:])
    return p


def _get_top_python_frame(frames):
    """Get the first meaningful Python frame from a trace frame list.

    Only dict frames whose ``filename`` ends in ``.py`` are considered.
    Frames whose filename contains one of the internal patterns below are
    skipped to find the user-level code that triggered the allocation.
    Falls back to the first Python frame, or ``None`` if there is none.
    """
    # NOTE: never put an empty string in this tuple -- `"" in fname` is always
    # true, which would skip every frame and defeat the attribution.
    skip_patterns = (
        "torch/autograd/",
        "torch/utils/checkpoint",
        "torch/nn/modules/module.py",
        "torch/_dynamo/",
        "torch/_compile",
        "torch/_ops",
        "torch/_functorch/",
        "torch/_inductor/",
        "torch/library",
        "torch/cuda",
        "torch/_C",
        "fire/core.py",
        "runpy",
    )
    py_frames = [
        f
        for f in frames
        if isinstance(f, dict) and f.get("filename", "").endswith(".py")
    ]
    for fr in py_frames:
        fname = fr.get("filename", "")
        if not any(p in fname for p in skip_patterns):
            return fr
    return py_frames[0] if py_frames else None


def _categorize_source(fname, funcname):
    """Assign a human-readable category to an allocation source."""
    if "checkpoint" in fname.lower() or "backward" in funcname:
        return "Gradient checkpoint recompute"
    if "bitsandbytes" in fname:
        return "BnB dequantization"
    if "scattermoe" in fname or "scatter" in funcname.lower():
        return "ScatterMoE LoRA"
    if "adam" in funcname.lower() or "optim" in fname.lower() or "inductor" in fname:
        return "Optimizer"
    if "norm" in funcname.lower():
        return "LayerNorm"
    if "fla/" in fname or "gated_delta" in fname:
        return "FLA linear attention"
    return "Other"


def analyze_allocation_churn(snapshot):
    """Group allocation churn by Python source attribution.

    Takes the five allocation sizes with the largest ``count * size`` in the
    first device trace and, for each one of at least 1 MB, groups its "alloc"
    events by the frame chosen by ``_get_top_python_frame``.

    Returns:
        ``None`` when the snapshot has no trace events, otherwise a dict
        ``size -> {"total_count", "total_churn_gb", "categories"}``.
    """
    # `or [[]]` also covers a present-but-empty device_traces list.
    traces = (snapshot.get("device_traces") or [[]])[0]
    if not traces:
        return None

    # Find the top churn sizes
    size_counts = defaultdict(int)
    for ev in traces:
        if ev.get("action") == "alloc":
            size_counts[ev.get("size", 0)] += 1

    # Top 5 sizes by total churn (count × size)
    top_sizes = sorted(size_counts.items(), key=lambda x: x[0] * x[1], reverse=True)[:5]

    results = {}
    for target_sz, total_count in top_sizes:
        mb = target_sz / 1e6
        if mb < 1.0:
            continue

        frame_groups = defaultdict(lambda: {"count": 0})
        for ev in traces:
            if ev.get("action") == "alloc" and ev.get("size", 0) == target_sz:
                top = _get_top_python_frame(ev.get("frames", []))
                if top:
                    key = (
                        top.get("filename", ""),
                        top.get("name", ""),
                        top.get("line", 0),
                    )
                else:
                    key = ("", "", 0)
                frame_groups[key]["count"] += 1

        # Group into categories
        categories = defaultdict(lambda: {"count": 0, "sources": []})
        for (fname, funcname, lineno), info in frame_groups.items():
            cat = _categorize_source(fname, funcname)
            categories[cat]["count"] += info["count"]
            categories[cat]["sources"].append(
                (funcname, _short_path(fname), lineno, info["count"])
            )

        results[target_sz] = {
            "total_count": total_count,
            "total_churn_gb": total_count * target_sz / 1e9,
            "categories": dict(categories),
        }

    return results


def print_allocation_churn(results, label=""):
    """Print the per-size, per-category output of analyze_allocation_churn."""
    if not results:
        return

    if label:
        print(f"\n{'=' * 75}")
        print(f" {label}")
        print(f"{'=' * 75}")
    else:
        print("\n ALLOCATION CHURN BY SOURCE")

    for target_sz in sorted(
        results, key=lambda s: results[s]["total_churn_gb"], reverse=True
    ):
        info = results[target_sz]
        mb = target_sz / 1e6
        print(
            f"\n {mb:.1f}MB × {info['total_count']:,} = {info['total_churn_gb']:.0f}GB churn:"
        )
        for cat, cinfo in sorted(
            info["categories"].items(), key=lambda x: x[1]["count"], reverse=True
        ):
            pct = cinfo["count"] / info["total_count"] * 100
            print(f" {pct:>4.0f}% {cat} ({cinfo['count']} allocs)")
            # Show at most the three busiest call sites per category
            for funcname, short, lineno, cnt in sorted(
                cinfo["sources"], key=lambda x: x[3], reverse=True
            )[:3]:
                print(f" {funcname:35s} {short}:{lineno} (x{cnt})")
---- CPU overhead analysis ------------------------------------------------ def analyze_cpu_overhead(events): """Analyze memcpy, checkpoint recomputation, and GPU utilization from trace events.""" cuda_total_us = 0 cuda_count = 0 memcpy_stats = defaultdict(lambda: {"dur_us": 0, "count": 0, "bytes": 0}) checkpoint_us = 0 checkpoint_count = 0 min_ts = float("inf") max_ts_end = 0 for ev in events: if ev.get("ph") != "X": continue cat = ev.get("cat", "") name = ev.get("name", "") dur = float(ev.get("dur", 0)) ts = float(ev.get("ts", 0)) end = ts + dur if ts < min_ts: min_ts = ts if end > max_ts_end: max_ts_end = end if cat == "kernel": cuda_total_us += dur cuda_count += 1 elif cat == "gpu_memcpy": if "DtoH" in name: direction = "GPU→CPU (offload)" elif "HtoD" in name: direction = "CPU→GPU (reload)" elif "DtoD" in name: direction = "GPU→GPU" else: direction = name memcpy_stats[direction]["dur_us"] += dur memcpy_stats[direction]["count"] += 1 nbytes = ev.get("args", {}).get("Bytes", 0) if nbytes: memcpy_stats[direction]["bytes"] += int(nbytes) elif cat in ("cpu_op", "python_function"): nl = name.lower() if "checkpoint" in nl or "recompute" in nl: checkpoint_us += dur checkpoint_count += 1 wall_us = max_ts_end - min_ts if max_ts_end > min_ts else 0 return { "wall_clock_us": wall_us, "cuda_total_us": cuda_total_us, "cuda_kernel_count": cuda_count, "memcpy_stats": dict(memcpy_stats), "checkpoint_us": checkpoint_us, "checkpoint_count": checkpoint_count, } def print_cpu_overhead(result, n_steps=2, label=""): if not result: return if label: print(f"\n{'=' * 75}") print(f" {label}") print(f"{'=' * 75}") else: print("\n CPU OVERHEAD ANALYSIS") wall = result["wall_clock_us"] cuda = result["cuda_total_us"] count = result["cuda_kernel_count"] gpu_util = cuda / wall * 100 if wall > 0 else 0 cpu_gap = wall - cuda print( f"\n Wall clock: {wall / 1e6:.2f}s (~{wall / n_steps / 1e6:.2f}s/step)" ) print(f" CUDA kernel time: {cuda / 1e6:.2f}s ({count:,} kernels)") print(f" GPU utilization: 
{gpu_util:.1f}%") print(f" CPU overhead: {cpu_gap / 1e6:.2f}s ({100 - gpu_util:.1f}%)") memcpy = result["memcpy_stats"] if memcpy: total_memcpy = sum(v["dur_us"] for v in memcpy.values()) total_bytes = sum(v["bytes"] for v in memcpy.values()) print( f"\n Memory transfers: {total_memcpy / 1e6:.3f}s " f"({total_bytes / 1e9:.2f}GB, {total_memcpy / wall * 100:.1f}% of wall)" ) for direction, info in sorted( memcpy.items(), key=lambda x: x[1]["dur_us"], reverse=True ): gb = info["bytes"] / 1e9 print( f" {direction:30s} {info['dur_us'] / 1e6:.3f}s " f"x{info['count']:>5} {gb:.2f}GB" ) if result["checkpoint_count"] > 0: print( f"\n Gradient checkpoint CPU ops: {result['checkpoint_us'] / 1e6:.3f}s " f"(x{result['checkpoint_count']})" ) # ---- Memory snapshot analysis --------------------------------------------- def load_snapshot(path): """Load a PyTorch CUDA memory snapshot from a pickle file. WARNING: This uses pickle.load() which can execute arbitrary code. Only load snapshot files that you generated yourself from trusted training runs. Never load snapshots from untrusted sources. """ snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path) if not snap_file.exists(): return None print(f"Loading {snap_file.name} ({snap_file.stat().st_size / 1e6:.0f} MB)...") with open(snap_file, "rb") as f: return pickle.load(f) # nosec B301 def _extract_python_frames(snapshot): """Extract Python source attribution from snapshot blocks with stacks='all'. The snapshot structure (when stacks='all') stores frames in: segments[i].blocks[j].history[k].frames = [(filename, lineno, name), ...] 
Returns a dict mapping (filename, function_name) -> {"bytes": int, "count": int} """ source_allocs = defaultdict(lambda: {"bytes": 0, "count": 0}) for seg in snapshot.get("segments", []): for block in seg.get("blocks", []): if block.get("state") != "active_allocated": continue size = block.get("size", 0) history = block.get("history", []) if not history: continue # Use the most recent allocation history entry last_hist = history[-1] frames = last_hist.get("frames", []) # Find the first Python frame (skip C++ frames) # Frames are tuples: (filename, lineno, name) attributed = False for frame in frames: if not isinstance(frame, (list, tuple)) or len(frame) < 3: continue filename, lineno, funcname = frame[0], frame[1], frame[2] # Skip internal torch/cuda frames to find user-level attribution fname_str = str(filename) if any( skip in fname_str for skip in [ "torch/cuda", "torch/_C", "torch/utils", "cuda/memory.py", "", ] ): continue key = (fname_str, funcname, lineno) source_allocs[key]["bytes"] += size source_allocs[key]["count"] += 1 attributed = True break # If no user frame found, use first available frame if not attributed and frames: frame = frames[0] if isinstance(frame, (list, tuple)) and len(frame) >= 3: key = (str(frame[0]), str(frame[2]), frame[1]) source_allocs[key]["bytes"] += size source_allocs[key]["count"] += 1 return dict(source_allocs) def _extract_source_file_summary(source_allocs): """Aggregate per-frame allocations to per-file level.""" file_allocs = defaultdict(lambda: {"bytes": 0, "count": 0, "functions": set()}) for (filename, funcname, _lineno), info in source_allocs.items(): file_allocs[filename]["bytes"] += info["bytes"] file_allocs[filename]["count"] += info["count"] file_allocs[filename]["functions"].add(funcname) return dict(file_allocs) def analyze_snapshot(snapshot): segments = snapshot.get("segments", []) total_reserved = sum(s.get("total_size", 0) for s in segments) total_allocated = sum(s.get("allocated_size", 0) for s in segments) # 
Active blocks active_blocks = [] for seg in segments: for block in seg.get("blocks", []): if block.get("state") == "active_allocated": active_blocks.append(block.get("size", 0)) # Allocation churn from trace trace = snapshot.get("device_traces", [[]])[0] size_counts = defaultdict(lambda: {"count": 0, "total": 0}) for ev in trace: if ev.get("action") == "alloc": sz = ev.get("size", 0) size_counts[sz]["count"] += 1 size_counts[sz]["total"] += sz # Python frame attribution source_allocs = _extract_python_frames(snapshot) file_summary = _extract_source_file_summary(source_allocs) return { "total_reserved": total_reserved, "total_allocated": total_allocated, "fragmentation_pct": (total_reserved - total_allocated) / total_reserved * 100 if total_reserved > 0 else 0, "n_segments": len(segments), "n_active_blocks": len(active_blocks), "active_bytes": sum(active_blocks), "largest_active": sorted(active_blocks, reverse=True)[:10], "alloc_churn": dict( sorted(size_counts.items(), key=lambda x: x[1]["total"], reverse=True)[:15] ), "n_trace_events": len(trace), "source_allocs": source_allocs, "file_summary": file_summary, } def print_memory_analysis(result, label=""): if label: print(f"\n{'=' * 75}") print(f" {label}") print(f"{'=' * 75}") reserved = result["total_reserved"] allocated = result["total_allocated"] print(f"\n Reserved: {reserved / 1e9:.2f} GB") print(f" Allocated: {allocated / 1e9:.2f} GB") print(f" Utilization: {allocated / reserved * 100:.1f}%" if reserved > 0 else "") print(f" Fragmentation: {result['fragmentation_pct']:.1f}%") print( f" Segments: {result['n_segments']}, Active blocks: {result['n_active_blocks']}" ) print("\n Largest active allocations:") for sz in result["largest_active"]: print(f" {sz / 1e6:>10.1f} MB") # Python source file attribution file_summary = result.get("file_summary", {}) if file_summary: print("\n Top allocations by source file:") print(f" {'Source file':<55} {'Alloc':>10} {'Count':>7}") print(f" {'-' * 75}") for fname, info in 
sorted( file_summary.items(), key=lambda x: x[1]["bytes"], reverse=True )[:15]: # Shorten path for display short = fname if len(short) > 55: parts = short.split("/") # Keep last 3 path components short = ".../" + "/".join(parts[-3:]) if len(short) > 55: short = short[:52] + "..." funcs = ", ".join(sorted(info["functions"])[:3]) sz = info["bytes"] unit = "MB" val = sz / 1e6 if val >= 1000: unit = "GB" val = sz / 1e9 print(f" {short:<55} {val:>8.1f}{unit} {info['count']:>7}") if funcs: print(f" functions: {funcs[:70]}") # Top allocations by function (more granular) source_allocs = result.get("source_allocs", {}) if source_allocs: print("\n Top allocations by function (with line numbers):") print(f" {'Function':<35} {'File:Line':<35} {'Size':>10}") print(f" {'-' * 82}") for (fname, funcname, lineno), info in sorted( source_allocs.items(), key=lambda x: x[1]["bytes"], reverse=True )[:15]: # Shorten filename short_file = fname parts = short_file.split("/") if len(parts) > 2: short_file = "/".join(parts[-2:]) loc = f"{short_file}:{lineno}" if len(loc) > 35: loc = "..." + loc[-32:] sz = info["bytes"] if sz >= 1e9: sz_str = f"{sz / 1e9:.2f}GB" else: sz_str = f"{sz / 1e6:.1f}MB" print(f" {funcname:<35} {loc:<35} {sz_str:>10}") print("\n Top allocation churn (alloc count x size):") print(f" {'Size':>12} {'Count':>8} {'Total churned':>14}") print(f" {'-' * 38}") for sz, info in result["alloc_churn"].items(): if sz >= 1e6: print( f" {sz / 1e6:>10.1f}MB {info['count']:>8} " f"{info['total'] / 1e9:>12.2f}GB" ) # ---- Peak memory timeline from trace events -------------------------------- def analyze_peak_memory(snapshot): """Walk through device_traces chronologically to find peak concurrent memory usage. The snapshot's segment data only captures end-of-step state. The device_traces record every alloc/free, letting us reconstruct peak usage and identify which allocation sources were live at that moment. 
""" traces = snapshot.get("device_traces", [[]])[0] if not traces: return None current = 0 peak = 0 peak_idx = 0 live_allocs = {} # addr -> (size, frames) peak_live = {} for i, ev in enumerate(traces): action = ev.get("action") addr = ev.get("addr", 0) size = ev.get("size", 0) if action == "alloc": current += size live_allocs[addr] = (size, ev.get("frames", [])) if current > peak: peak = current peak_idx = i peak_live = dict(live_allocs) elif action == "free_requested": if addr in live_allocs: current -= live_allocs[addr][0] del live_allocs[addr] # Categorize allocations at peak peak_categories = defaultdict(lambda: {"bytes": 0, "count": 0}) for _addr, (size, frames) in peak_live.items(): top = _get_top_python_frame(frames) if top: cat = _categorize_source(top["filename"], top["name"]) else: cat = "Unknown" peak_categories[cat]["bytes"] += size peak_categories[cat]["count"] += 1 return { "peak_bytes": peak, "peak_event_idx": peak_idx, "total_events": len(traces), "end_bytes": current, "peak_categories": dict(peak_categories), } def print_peak_memory(result, mem_result=None, label=""): if not result: return if label: print(f"\n{'=' * 75}") print(f" {label}") print(f"{'=' * 75}") peak_gb = result["peak_bytes"] / 1e9 end_gb = result["end_bytes"] / 1e9 # The device_traces only record allocations AFTER profiling starts. # Model weights and other persistent allocations are not tracked. # We can estimate the persistent baseline from snapshot allocated - peak_traced. 
persistent_gb = 0 if mem_result: persistent_gb = mem_result["total_allocated"] / 1e9 - end_gb total_peak_gb = persistent_gb + peak_gb print( f"\n Profiled peak (transient): {peak_gb:.2f} GB " f"(at event {result['peak_event_idx']:,} / {result['total_events']:,})" ) if persistent_gb > 0: print( f" Persistent baseline: {persistent_gb:.2f} GB " f"(model + optimizer, allocated before profiling)" ) print(f" Estimated total peak: {total_peak_gb:.2f} GB") print(f" Transient headroom: {peak_gb - end_gb:.2f} GB above end-of-trace") cats = result.get("peak_categories", {}) if cats: print("\n Allocations live at peak:") print(f" {'Category':<35} {'Size':>10} {'Count':>7}") print(f" {'-' * 55}") for cat, info in sorted( cats.items(), key=lambda x: x[1]["bytes"], reverse=True ): sz = info["bytes"] if sz >= 1e9: sz_str = f"{sz / 1e9:.2f} GB" else: sz_str = f"{sz / 1e6:.1f} MB" print(f" {cat:<35} {sz_str:>10} {info['count']:>7}") # ---- Fragmentation diagnosis ----------------------------------------------- def analyze_fragmentation(snapshot): """Analyze segment-level memory layout to explain fragmentation. Examines each CUDA segment for inactive (freed but unreturned) blocks, pinned small allocations that prevent segment merging, and the overall segment size distribution. 
""" segments = snapshot.get("segments", []) if not segments: return None total_reserved = 0 total_allocated = 0 total_inactive = 0 segment_sizes = [] inactive_gaps = [] # (gap_size, segment_size, active_around) pinned_fragments = [] # small active blocks surrounded by inactive for seg in segments: seg_size = seg.get("total_size", 0) total_reserved += seg_size segment_sizes.append(seg_size) blocks = seg.get("blocks", []) seg_active = 0 seg_inactive = 0 for bi, block in enumerate(blocks): bsize = block.get("size", 0) if block.get("state") == "active_allocated": seg_active += bsize total_allocated += bsize # Check if this small block is surrounded by inactive if bsize < 2 * 1024 * 1024: # < 2MB prev_inactive = bi > 0 and blocks[bi - 1].get("state") == "inactive" next_inactive = ( bi < len(blocks) - 1 and blocks[bi + 1].get("state") == "inactive" ) if prev_inactive and next_inactive: pinned_fragments.append((bsize, seg_size)) elif block.get("state") == "inactive": seg_inactive += bsize total_inactive += bsize inactive_gaps.append((bsize, seg_size)) # Classify segment sizes size_buckets = defaultdict(lambda: {"count": 0, "total": 0}) for sz in segment_sizes: if sz >= 1024 * 1024 * 1024: bucket = ">=1 GB" elif sz >= 256 * 1024 * 1024: bucket = "256MB-1GB" elif sz >= 64 * 1024 * 1024: bucket = "64-256MB" elif sz >= 2 * 1024 * 1024: bucket = "2-64MB" else: bucket = "<2MB" size_buckets[bucket]["count"] += 1 size_buckets[bucket]["total"] += sz # Large inactive gaps that could be reclaimed inactive_gaps.sort(key=lambda x: x[0], reverse=True) return { "total_reserved": total_reserved, "total_allocated": total_allocated, "total_inactive": total_inactive, "n_segments": len(segments), "segment_size_buckets": dict(size_buckets), "large_inactive_gaps": inactive_gaps[:20], "pinned_fragments": len(pinned_fragments), "expandable_segments_would_help": ( total_inactive > 0.1 * total_reserved and len(segments) > 10 ), } def print_fragmentation(result, gpu_capacity_gb=None, label=""): if 
def print_fragmentation(result, gpu_capacity_gb=None, label=""):
    """Print a fragmentation report for one memory snapshot.

    Args:
        result: Dict as returned by ``analyze_fragmentation``. Keys read here:
            ``total_reserved``, ``total_allocated``, ``total_inactive``,
            ``n_segments``, ``pinned_fragments``, ``segment_size_buckets`` and,
            optionally, ``large_inactive_gaps`` and
            ``expandable_segments_would_help``. A falsy value prints nothing.
        gpu_capacity_gb: Total GPU memory in GB (1e9 bytes). When falsy the
            OOM risk section is skipped.
        label: Optional banner title printed above the report.

    Returns:
        None. All output goes to stdout.
    """
    if not result:
        return

    if label:
        print(f"\n{'=' * 75}")
        print(f" {label}")
        print(f"{'=' * 75}")

    reserved = result["total_reserved"]
    allocated = result["total_allocated"]
    inactive = result["total_inactive"]
    # Fragmentation is reported as the inactive share of reserved memory.
    frag_pct = inactive / reserved * 100 if reserved > 0 else 0

    print(
        f"\n Reserved: {reserved / 1e9:.2f} GB across {result['n_segments']} segments"
    )
    print(f" Allocated: {allocated / 1e9:.2f} GB")
    print(f" Inactive: {inactive / 1e9:.2f} GB ({frag_pct:.1f}% fragmentation)")

    if result["pinned_fragments"] > 0:
        print(
            f" Pinned small blocks (<2MB between inactive): "
            f"{result['pinned_fragments']} (prevent segment merging)"
        )

    # Segment size distribution, largest bucket first.
    print("\n Segment size distribution:")
    for bucket in (">=1 GB", "256MB-1GB", "64-256MB", "2-64MB", "<2MB"):
        info = result["segment_size_buckets"].get(bucket)
        if info:
            print(
                f" {bucket:<12} {info['count']:>4} segments "
                f"{info['total'] / 1e9:>6.2f} GB"
            )

    # Largest inactive gaps: only gaps of at least 32 MB are shown, at most 10.
    # Filter first so the header is not printed above an empty list.
    min_gap_bytes = 32 * 1024 * 1024
    all_gaps = result.get("large_inactive_gaps") or []
    notable_gaps = [gap for gap in all_gaps if gap[0] >= min_gap_bytes][:10]
    if notable_gaps:
        print("\n Largest inactive gaps (freed but unreclaimable):")
        for gap_sz, seg_sz in notable_gaps:
            print(f" {gap_sz / 1e6:>8.0f} MB gap in {seg_sz / 1e6:.0f} MB segment")

    # OOM risk assessment
    if gpu_capacity_gb:
        gpu_bytes = gpu_capacity_gb * 1e9
        # Subtract exactly the figure quoted in the message below so the
        # printed numbers add up (capacity - inactive = usable).
        usable = gpu_bytes - inactive
        print(f"\n OOM Risk Assessment (GPU: {gpu_capacity_gb:.1f} GB):")
        print(
            f" Usable capacity: {usable / 1e9:.2f} GB "
            f"(GPU capacity minus {inactive / 1e9:.2f} GB fragmentation)"
        )
        headroom = gpu_bytes - reserved
        print(f" Current headroom: {headroom / 1e9:.2f} GB")
        if headroom < 1.0e9:
            print(" ⚠ CRITICAL: <1 GB headroom — high OOM risk!")
        elif headroom < 2.0e9:
            print(" ⚠ WARNING: <2 GB headroom — moderate OOM risk")

    # Recommendation
    if result.get("expandable_segments_would_help"):
        print("\n → FIX: Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
        print(" This eliminates segment fragmentation by growing segments in-place,")
        print(
            f" which would reclaim up to {inactive / 1e9:.1f} GB of wasted memory."
        )
# ---- Sequence-length scaling analysis --------------------------------------


def analyze_scaling(mem_a, mem_b, churn_a, churn_b):
    """Compare per-tensor allocation sizes between two runs.

    When two profiles differ only in sequence length, this shows which tensor
    categories scale with sequence length by comparing the dominant tensor
    sizes. Total churn may differ due to different profiling windows, so we
    focus on per-tensor size ratios instead.

    Args:
        mem_a: Not read; kept so existing callers keep working.
        mem_b: Not read; kept so existing callers keep working.
        churn_a: Mapping of allocation size (bytes) -> info dict for run A.
            Each info dict may hold ``"categories"``: a mapping of category
            name -> dict with a ``"count"`` entry.
        churn_b: Same structure for run B.

    Returns:
        ``None`` if either churn mapping is empty/falsy. Otherwise a list of
        dicts (one per category present in both runs whose largest size
        exceeds 1e6 bytes in both), sorted by run B's size descending, with
        keys ``category``, ``size_a_mb``, ``size_b_mb``, ``tensor_ratio``,
        ``count_a``, ``count_b`` and ``scales_with_seqlen``.
    """
    if not churn_a or not churn_b:
        return None

    def _cat_summary(churn):
        """Map category -> {"max_size": largest size seen, "count": total count}."""
        summary = defaultdict(lambda: {"max_size": 0, "count": 0})
        for sz, info in churn.items():
            for cat, cinfo in info.get("categories", {}).items():
                entry = summary[cat]
                entry["count"] += cinfo["count"]
                if sz > entry["max_size"]:
                    entry["max_size"] = sz
        return summary

    cats_a = _cat_summary(churn_a)
    cats_b = _cat_summary(churn_b)

    rows = []
    # Only categories seen in both runs can be compared.
    for cat in cats_a.keys() & cats_b.keys():
        a_max = cats_a[cat]["max_size"]
        b_max = cats_b[cat]["max_size"]
        if a_max > 1e6 and b_max > 1e6:  # Only compare >1MB tensors
            rows.append((cat, a_max, b_max))

    # Primary key: size in run B (descending). Ties fall back to the larger of
    # the two sizes, then to the category name, so the order never depends on
    # set iteration order.
    rows.sort(key=lambda r: (-r[2], -max(r[1], r[2]), str(r[0])))

    scaling = []
    for cat, a_max, b_max in rows:
        tensor_ratio = b_max / a_max  # a_max > 1e6, so never a zero division
        scaling.append(
            {
                "category": cat,
                "size_a_mb": a_max / 1e6,
                "size_b_mb": b_max / 1e6,
                "tensor_ratio": tensor_ratio,
                "count_a": cats_a[cat]["count"],
                "count_b": cats_b[cat]["count"],
                "scales_with_seqlen": tensor_ratio > 1.05,
            }
        )
    return scaling


def print_scaling(scaling, label_a="Before", label_b="After", label=""):
    """Print the per-category size comparison produced by ``analyze_scaling``.

    Args:
        scaling: List of entry dicts with keys ``category``, ``size_a_mb``,
            ``size_b_mb``, ``tensor_ratio`` and ``scales_with_seqlen``.
            A falsy value prints nothing.
        label_a: Name of run A, shown in the legend above the table.
        label_b: Name of run B, shown in the legend above the table.
        label: Optional banner title printed above the report.

    Returns:
        None. All output goes to stdout.
    """
    if not scaling:
        return

    if label:
        print(f"\n{'=' * 75}")
        print(f" {label}")
        print(f"{'=' * 75}")

    print("\n Per-tensor size comparison (largest tensor per category):")
    # Legend: the table columns only say "A" and "B", so name the runs here.
    print(f" A = {label_a}")
    print(f" B = {label_b}")
    print(
        f" {'Category':<35} {'A size':>10} {'B size':>10} {'Ratio':>7} {'Scales?':>8}"
    )
    print(f" {'-' * 73}")

    for entry in scaling:
        ratio = entry["tensor_ratio"]
        ratio_str = f"{ratio:.2f}x" if ratio else "N/A"
        scales = "YES" if entry["scales_with_seqlen"] else "no"
        print(
            f" {entry['category']:<35} {entry['size_a_mb']:>8.1f}MB "
            f"{entry['size_b_mb']:>8.1f}MB {ratio_str:>7} {scales:>8}"
        )

    # Summary
    seq_scaling = [e for e in scaling if e["scales_with_seqlen"]]
    constant = [e for e in scaling if not e["scales_with_seqlen"]]

    if seq_scaling:
        ratios = [e["tensor_ratio"] for e in seq_scaling if e["tensor_ratio"]]
        avg_ratio = sum(ratios) / len(ratios) if ratios else 0
        print(f"\n Sequence-length scaling detected ({avg_ratio:.2f}x avg):")
        for e in seq_scaling:
            print(
                f" - {e['category']}: {e['size_a_mb']:.1f}MB -> "
                f"{e['size_b_mb']:.1f}MB ({e['tensor_ratio']:.2f}x)"
            )

    if constant:
        print("\n Constant-size categories (do not scale with seq len):")
        for e in constant:
            print(f" - {e['category']}: {e['size_a_mb']:.1f}MB")
# ---- Main -----------------------------------------------------------------


def _detect_gpu_capacity_gb():
    """Return total memory of CUDA device 0 in GB (1e9 bytes), or None.

    Detection is best-effort: if torch is not installed or the CUDA query
    fails for any reason, None is returned and the caller skips the OOM
    risk assessment.
    """
    try:
        import torch

        if torch.cuda.is_available():
            return torch.cuda.get_device_properties(0).total_memory / 1e9
    except Exception:
        # Deliberately broad: a failed probe must never abort the analysis.
        return None
    return None


def main():
    """Command-line entry point.

    Parses arguments, then runs (in order) trace analysis, CPU overhead
    analysis, memory snapshot analysis and the final summary. When
    ``--compare`` is given, the same analyses are run on the second path and
    comparison tables are printed.
    """
    parser = argparse.ArgumentParser(
        description="Analyze axolotl training profiler output",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "path",
        help="Path to output directory (containing profiler_trace.json and/or "
        "snapshot.pickle) or directly to a trace file. "
        "Security note: snapshot.pickle uses pickle deserialization — "
        "only use files from your own trusted training runs.",
    )
    parser.add_argument(
        "--compare",
        help="Path to second run for A/B comparison. "
        "Same security note as path: only use trusted snapshot files.",
    )
    parser.add_argument(
        "--include-warmup",
        action="store_true",
        help="Include step 0 (warmup/compilation) in timing analysis",
    )
    parser.add_argument(
        "--memory-only",
        action="store_true",
        help="Only analyze memory snapshot, skip trace",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Only load first 2M events for rapid analysis of large traces",
    )
    parser.add_argument(
        "--gpu-gb",
        type=float,
        default=None,
        help="GPU total memory in GB (for OOM risk assessment). "
        "Auto-detected if not specified.",
    )
    args = parser.parse_args()

    # Auto-detect GPU capacity if not specified
    gpu_capacity_gb = args.gpu_gb
    if gpu_capacity_gb is None:
        gpu_capacity_gb = _detect_gpu_capacity_gb()

    skip = not args.include_warmup
    trace_result = None
    mem_result = None
    events = None

    # -- Trace analysis --
    if not args.memory_only:
        events = load_trace(args.path, quick=args.quick)
        if events:
            trace_result = analyze_trace(events, skip_warmup=skip)
            if trace_result:
                print_trace_analysis(trace_result, label=f"Trace: {args.path}")

                if args.compare:
                    events2 = load_trace(args.compare, quick=args.quick)
                    if events2:
                        result2 = analyze_trace(events2, skip_warmup=skip)
                        if result2:
                            print_trace_analysis(
                                result2, label=f"Trace: {args.compare}"
                            )
                            compare_traces(trace_result, result2)
                    else:
                        # Tell the user why no trace comparison was printed.
                        print(f" No profiler_trace.json found in {args.compare}")
        else:
            print(f" No profiler_trace.json found in {args.path}")

    # -- CPU overhead analysis (from trace) --
    if events and trace_result:
        cpu_result = analyze_cpu_overhead(events)
        print_cpu_overhead(
            cpu_result,
            n_steps=trace_result["n_steps"],
            label=f"CPU Overhead: {args.path}",
        )

    # -- Memory analysis --
    snapshot = load_snapshot(args.path)
    churn_result = None
    churn2 = None
    if snapshot:
        mem_result = analyze_snapshot(snapshot)
        print_memory_analysis(mem_result, label=f"Memory: {args.path}")

        # Peak memory timeline
        peak_result = analyze_peak_memory(snapshot)
        print_peak_memory(
            peak_result, mem_result=mem_result, label=f"Peak Memory: {args.path}"
        )

        # Fragmentation diagnosis
        frag_result = analyze_fragmentation(snapshot)
        print_fragmentation(
            frag_result,
            gpu_capacity_gb=gpu_capacity_gb,
            label=f"Fragmentation: {args.path}",
        )

        # Allocation churn attribution
        churn_result = analyze_allocation_churn(snapshot)
        if churn_result:
            print_allocation_churn(churn_result, label=f"Allocation Churn: {args.path}")

        if args.compare:
            snapshot2 = load_snapshot(args.compare)
            if snapshot2:
                mem2 = analyze_snapshot(snapshot2)
                print_memory_analysis(mem2, label=f"Memory: {args.compare}")

                # Peak memory for comparison
                peak2 = analyze_peak_memory(snapshot2)
                print_peak_memory(
                    peak2, mem_result=mem2, label=f"Peak Memory: {args.compare}"
                )

                # Fragmentation for comparison
                frag2 = analyze_fragmentation(snapshot2)
                print_fragmentation(
                    frag2,
                    gpu_capacity_gb=gpu_capacity_gb,
                    label=f"Fragmentation: {args.compare}",
                )

                churn2 = analyze_allocation_churn(snapshot2)
                if churn2:
                    print_allocation_churn(
                        churn2, label=f"Allocation Churn: {args.compare}"
                    )

                # Memory comparison summary. Skipped when either analysis
                # produced no result, so the lookups below cannot hit None.
                if mem_result and mem2:
                    print("\n Memory comparison:")
                    print(
                        f" Reserved: {mem_result['total_reserved'] / 1e9:.2f} -> "
                        f"{mem2['total_reserved'] / 1e9:.2f} GB"
                    )
                    print(
                        f" Allocated: {mem_result['total_allocated'] / 1e9:.2f} -> "
                        f"{mem2['total_allocated'] / 1e9:.2f} GB"
                    )
                    print(
                        f" Frag: {mem_result['fragmentation_pct']:.1f}% -> "
                        f"{mem2['fragmentation_pct']:.1f}%"
                    )
                    if peak_result and peak2:
                        print(
                            f" Peak: {peak_result['peak_bytes'] / 1e9:.2f} -> "
                            f"{peak2['peak_bytes'] / 1e9:.2f} GB"
                        )

                # Scaling analysis
                if churn_result and churn2:
                    scaling = analyze_scaling(mem_result, mem2, churn_result, churn2)
                    print_scaling(
                        scaling,
                        label_a=str(args.path),
                        label_b=str(args.compare),
                        label="Allocation Scaling Analysis",
                    )
            else:
                # Tell the user why no memory comparison was printed.
                print(f" No snapshot.pickle found in {args.compare}")
    elif args.memory_only:
        # A missing snapshot only matters when memory analysis was the
        # sole request; a trace-only run is fine without one.
        print(f" No snapshot.pickle found in {args.path}")

    # -- Summary --
    if trace_result:
        print_summary(trace_result, mem_result=mem_result)


if __name__ == "__main__":
    main()