From 29fa4dedbb7885bd0594937d40779b36ab5e6a12 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 10 Apr 2026 16:46:17 -0400 Subject: [PATCH] Gemma4 fixes and profiler (#3591) --- README.md | 30 +- docs/agents/model_architectures.md | 88 + docs/agents/sft.md | 24 + docs/multimodal.qmd | 35 + examples/gemma4/e2b-vision-lora.yaml | 62 + examples/qwen3.5/35b-a3b-moe-vision-lora.yaml | 62 + scripts/analyze_profile.py | 1518 +++++++++++++++++ src/axolotl/core/trainers/base.py | 17 + src/axolotl/integrations/liger/plugin.py | 50 + src/axolotl/monkeypatch/lora_kernels.py | 41 + 10 files changed, 1926 insertions(+), 1 deletion(-) create mode 100644 examples/gemma4/e2b-vision-lora.yaml create mode 100644 examples/qwen3.5/35b-a3b-moe-vision-lora.yaml create mode 100644 scripts/analyze_profile.py diff --git a/README.md b/README.md index 3d0710f0f..063ec8d3b 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ Features: **Requirements**: - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU -- Python 3.11 +- Python >=3.11 (3.12 recommended) - PyTorch ≥2.9.1 ### Google Colab @@ -95,6 +95,34 @@ Features: ### Installation +#### Using uv (recommended) + +```bash +# install uv if you don't already have it installed +curl -LsSf https://astral.sh/uv/install.sh | sh +source $HOME/.local/bin/env + +# CUDA 12.8.1 tends to have better package compatibility +export UV_TORCH_BACKEND=cu128 + +# create a new virtual environment +uv venv --python 3.12 +source .venv/bin/activate + +uv pip install torch==2.10.0 torchvision +uv pip install --no-build-isolation axolotl[deepspeed] + +# recommended - install cut-cross-entropy +uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main" + +# (optional) - prefetch flash-attn2 and causal-conv1d kernels +uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')" + +# Download example axolotl configs, deepspeed configs +axolotl fetch examples +axolotl fetch deepspeed_configs # OPTIONAL +``` + #### Using pip ```bash diff --git a/docs/agents/model_architectures.md b/docs/agents/model_architectures.md index 426db4ce9..9412b4883 100644 --- a/docs/agents/model_architectures.md +++ b/docs/agents/model_architectures.md @@ -2,6 +2,64 @@ Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families. +## VLM (Vision Language Model) Quick Start + +All VLM configs require these four lines: +```yaml +processor_type: AutoProcessor +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false +``` + +Decision tree for VLM config: +```text +Is the model multimodal (has vision/audio encoder)? + ├─ YES: Add `freeze_mm_modules: true` if training text only + │ Add `chat_template: ` (e.g. gemma4, qwen3_5, gemma3) + │ LoRA: use regex `lora_target_modules` to restrict to language model + └─ NO: Train as a regular text model + +Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)? + ├─ YES: Add `lora_target_parameters` for expert LoRA + │ Consider ScatterMoE kernels (see Plugins section) + └─ NO: Standard LoRA config +``` + +## Plugins & Optimizations + +### Cut Cross Entropy (CCE) + +Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. 
Install if not already present: + +```bash +uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main" +``` + +```yaml +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +``` + +### ScatterMoE Kernels + +Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts. + +```yaml +plugins: + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_scattermoe: true +experts_implementation: scattermoe + +# Expert LoRA targets (3D parameter tensors, not nn.Linear): +lora_target_parameters: + - experts.gate_up_proj + - experts.down_proj +``` + +Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations. + ## Gemma 4 **Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B` @@ -66,6 +124,36 @@ fsdp_config: experts_implementation: scattermoe ``` +### VLM (Vision) Training + +All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token. + +```yaml +base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B +processor_type: AutoProcessor +freeze_mm_modules: true +chat_template: gemma4 + +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false +``` + +A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations. + +For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add: +```yaml +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_scattermoe: true +experts_implementation: scattermoe +lora_target_parameters: + - experts.gate_up_proj + - experts.down_proj +``` + ### Common issues | Symptom | Cause | Fix | diff --git a/docs/agents/sft.md b/docs/agents/sft.md index 4d7b4e1a5..d3dfd39f7 100644 --- a/docs/agents/sft.md +++ b/docs/agents/sft.md @@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go | FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` | | DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` | +## Profiling + +To profile training and identify optimization opportunities: + +```yaml +# Profile steps 3-7 (after warmup/autotuning settles) +profiler_steps_start: 3 +profiler_steps: 5 +``` + +This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`. +View the Chrome trace at `chrome://tracing`. + +To programmatically inspect the trace: +```bash +python scripts/analyze_profile.py output_dir/ +``` + +The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. 
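+
+If you want a quick look without the full analyzer, here is a minimal sketch (assuming the
+standard torch.profiler Chrome trace layout, with kernel durations under `dur` in microseconds):
+
+```python
+import json
+from collections import defaultdict
+
+with open("output_dir/profiler_trace.json") as f:
+    events = json.load(f)["traceEvents"]
+
+# Sum GPU time per kernel name ("X" complete events with cat == "kernel")
+totals = defaultdict(float)
+for ev in events:
+    if ev.get("ph") == "X" and ev.get("cat") == "kernel":
+        totals[ev["name"]] += float(ev.get("dur", 0))  # microseconds
+
+# Print the ten most expensive kernels
+for name, us in sorted(totals.items(), key=lambda x: x[1], reverse=True)[:10]:
+    print(f"{us / 1000:8.1f} ms  {name[:80]}")
+```
+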
Look for: +- **Large matmul kernels**: candidates for fusion or quantization +- **Memory copies (H2D/D2H)**: unnecessary data movement +- **Small frequent kernels**: candidates for kernel fusion +- **Gaps between kernels**: pipeline bubbles from CPU overhead + Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd) ## File Map diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index e2a9440e8..aabff03f2 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -8,6 +8,7 @@ format: ## Supported Models +- [Gemma-4](#sec-gemma-4) *(NEW)* - [Mllama](#sec-mllama) - [Llama4](#sec-llama4) - [Pixtral](#sec-pixtral) @@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507 processor_type: VoxtralProcessor ``` +### Gemma-4 {#sec-gemma-4} + +All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training. + +```yaml +base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B + +chat_template: gemma4 +freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA + +# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA: +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_scattermoe: true +experts_implementation: scattermoe + +lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B: +lora_target_parameters: + - experts.gate_up_proj + - experts.down_proj +``` + +::: {.callout-warning} +Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details. +::: + +::: {.callout-tip} +For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`. +::: + ### Gemma-3 {#sec-gemma-3} ::: {.callout-tip} diff --git a/examples/gemma4/e2b-vision-lora.yaml b/examples/gemma4/e2b-vision-lora.yaml new file mode 100644 index 000000000..c779aaea5 --- /dev/null +++ b/examples/gemma4/e2b-vision-lora.yaml @@ -0,0 +1,62 @@ +# Gemma 4 E2B Vision LoRA +# +# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen. +# Uses the base ProcessingStrategy (auto-detects image_token from processor). 
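+#
+# To run this example (assuming the axolotl CLI is installed):
+#   axolotl train examples/gemma4/e2b-vision-lora.yaml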
+ +base_model: google/gemma-4-E2B-it +processor_type: AutoProcessor +freeze_mm_modules: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +strict: false + +# Required for vision/multimodal training +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: gemma4 +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:100] + +val_set_size: 0 +output_dir: ./outputs/gemma4-e2b-vision-lora + +adapter: lora +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 +# Target language model only — vision encoder is frozen via freeze_mm_modules +lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +max_steps: 10 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +sdp_attention: true + +warmup_ratio: 0.1 +weight_decay: 0.0 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: diff --git a/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml b/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml new file mode 100644 index 000000000..a7c85f785 --- /dev/null +++ b/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml @@ -0,0 +1,62 @@ +# Qwen 3.5 35B-A3B MoE Vision LoRA +# +# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model. +# 256 experts, 8 active per token, with early-fusion vision support. + +base_model: Qwen/Qwen3.5-35B-A3B +processor_type: AutoProcessor + +# Required for vision/multimodal training +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: qwen3_5 +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:100] + +val_set_size: 0 +output_dir: ./outputs/qwen35-35b-a3b-vision-lora + +adapter: lora +sequence_len: 4096 +pad_to_sequence_len: false + +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - down_proj + - up_proj + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +max_steps: 10 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +weight_decay: 0.0 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: diff --git a/scripts/analyze_profile.py b/scripts/analyze_profile.py new file mode 100644 index 000000000..df922502f --- /dev/null +++ b/scripts/analyze_profile.py @@ -0,0 +1,1518 @@ +#!/usr/bin/env python3 +""" +Axolotl Training Profiler Analyzer +=================================== + +Analyzes PyTorch profiler output from axolotl training runs (profiler_steps config). +Produces breakdowns by CUDA kernel category, identifies bottlenecks, and optionally +compares two traces (e.g. before/after optimization). 
+ +Supports both: + - profiler_trace.json (torch.profiler Chrome trace -- timing analysis) + - snapshot.pickle (torch.cuda.memory._snapshot -- memory analysis) + +Usage: + # Analyze a single trace + python analyze_profile.py outputs/qwen35_moe_profile/ + + # Compare before vs after + python analyze_profile.py outputs/before/ --compare outputs/after/ + + # Include step 0 (warmup/compilation) in analysis + python analyze_profile.py outputs/run/ --include-warmup + + # Memory-only analysis + python analyze_profile.py outputs/run/ --memory-only + + # Quick mode (first 2M events only, for large traces) + python analyze_profile.py outputs/run/ --quick +""" + +import argparse +import json +import pickle # nosec B403 +import time +from collections import defaultdict +from pathlib import Path + +# ---- Kernel categorization ------------------------------------------------ + +KERNEL_CATEGORIES = [ + # ScatterMoE -- ordered most specific first + ( + "ScatterMoE bwd LoRA (split)", + ["_group_bwd_lora_split", "_group_bwd_da", "_group_bwd_db"], + ), + ("ScatterMoE bwd LoRA (fused)", ["_group_bwd_lora"]), + ("ScatterMoE bwd dX", ["_scatter2scatter_lora_dx"]), + ("ScatterMoE fwd", ["_scatter2scatter_lora"]), + # Quantization + ("BnB Dequantization", ["dequantize", "kDequantizeBlockwise"]), + # Attention + ("Flash Attention", ["flash", "fmha"]), + # Loss + ("CCE Loss", ["_cce_"]), + # LoRA fused kernels (autograd.Function based) + ("LoRA QKV Kernel", ["lora_qkv"]), + ("LoRA O Kernel", ["lora_o"]), + ("LoRA MLP Kernel", ["lora_mlp"]), + # LoRA activation kernels (SwiGLU/GEGLU Triton kernels) + ("LoRA Activation (SwiGLU/GEGLU)", ["swiglu", "geglu"]), + # DoRA weight norm + ("DoRA Weight Norm", ["linalg_norm", "dora_scale"]), + # Compute + ("GEMM/CUTLASS", ["cutlass", "gemm", "gemv", "cublas"]), + ("Triton (norms etc)", ["triton"]), + ("Conv1d", ["conv1d", "causal_conv"]), + # Optimizer + ("Optimizer", ["adam", "optim"]), + # Dtype conversion (fp32→bf16 LoRA matrix casts etc) + ("Dtype Conversion", ["_to_copy", "to_copy"]), + # Memory + ("Elementwise/Fill", ["fill", "elementwise", "cast", "copy_kernel"]), + ("Memory ops", ["memcpy", "memset"]), + # Routing + ("TopK/Sort", ["topk", "sort"]), + ("Index/Gather/Scatter", ["index", "gather", "scatter"]), +] + +# Categories to keep when pre-filtering during streaming load +_KEEP_CATS = {"kernel", "gpu_memcpy", "cpu_op", "python_function", "ac2g", "Runtime"} + + +def categorize_kernel(name): + nl = name.lower() + for cat_name, patterns in KERNEL_CATEGORIES: + if any(p in nl for p in patterns): + return cat_name + return "Other" + + +# ---- Trace loading --------------------------------------------------------- + + +def _try_ijson_load(trace_file, quick=False, max_events=2_000_000): + """Stream-parse trace JSON with ijson. 
Returns filtered events list.""" + try: + import ijson + except ImportError: + return None + + events = [] + count = 0 + with open(trace_file, "rb") as f: + for ev in ijson.items(f, "traceEvents.item"): + count += 1 + cat = ev.get("cat", "") + # Pre-filter: only keep categories we care about + if cat in _KEEP_CATS or ev.get("ph") == "M": + events.append(ev) + if quick and count >= max_events: + print(f" --quick: stopped after {count:,} events") + break + return events, count + + +def load_trace(path, quick=False): + trace_file = ( + Path(path) / "profiler_trace.json" if Path(path).is_dir() else Path(path) + ) + if not trace_file.exists(): + return None + size_gb = trace_file.stat().st_size / 1e9 + print(f"Loading {trace_file.name} ({size_gb:.1f} GB)...") + t0 = time.monotonic() + + # Try streaming parser first for large files + if size_gb > 0.5: + result = _try_ijson_load(trace_file, quick=quick) + if result is not None: + events, total_count = result + elapsed = time.monotonic() - t0 + print( + f" {total_count:,} total events, {len(events):,} kept " + f"(streamed in {elapsed:.1f}s)" + ) + return events + + # Fallback: standard json.load + with open(trace_file) as f: + data = json.load(f) + all_events = data.get("traceEvents", []) + elapsed = time.monotonic() - t0 + + if quick: + all_events = all_events[:2_000_000] + print(f" --quick: limited to first {len(all_events):,} events") + + # Pre-filter + events = [ + ev + for ev in all_events + if ev.get("cat", "") in _KEEP_CATS or ev.get("ph") == "M" + ] + print( + f" {len(all_events):,} total events, {len(events):,} kept " + f"(loaded in {elapsed:.1f}s)" + ) + return events + + +# ---- Trace analysis ------------------------------------------------------- + + +def _estimate_n_steps(cuda_events): + """Estimate the number of training steps from CUDA event timestamps. + + Detects step boundaries by looking for large gaps (>2x median gap) in the + sorted timestamp sequence of CUDA kernels. + """ + if len(cuda_events) < 100: + return 1 + timestamps = sorted(float(ev.get("ts", 0)) for ev in cuda_events) + # Compute gaps between consecutive events + gaps = [timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)] + if not gaps: + return 1 + median_gap = sorted(gaps)[len(gaps) // 2] + # A step boundary has a gap much larger than the median inter-kernel gap + threshold = max(median_gap * 50, 100_000) # at least 100ms + n_boundaries = sum(1 for g in gaps if g > threshold) + return max(n_boundaries + 1, 1) + + +def analyze_trace(events, skip_warmup=True): + cuda_events = [ + ev + for ev in events + if ev.get("ph") == "X" and ev.get("cat") in ("kernel", "gpu_memcpy") + ] + if not cuda_events: + print(" No CUDA kernel events found!") + return None + + cutoff_ts = None + + if skip_warmup and len(cuda_events) > 1000: + timestamps = sorted(set(float(ev.get("ts", 0)) for ev in cuda_events)) + min_ts, max_ts = timestamps[0], timestamps[-1] + total_span = max_ts - min_ts + + # Step 0 is warmup (Triton compilation + autotune). It's typically + # the slowest step by far. Use 45% of wall-clock as cutoff -- step 0 + # usually takes >50% of total time when it includes compilation. 
+ cutoff_ts = min_ts + total_span * 0.45 + before = len(cuda_events) + cuda_events = [ev for ev in cuda_events if float(ev.get("ts", 0)) > cutoff_ts] + print(f" Excluding step 0 (warmup): {before:,} -> {len(cuda_events):,} events") + + n_steps_profiled = _estimate_n_steps(cuda_events) + + # Aggregate by kernel (cast to float to handle ijson Decimal values) + kernel_stats = defaultdict(lambda: {"total_us": 0.0, "count": 0, "max_us": 0.0}) + for ev in cuda_events: + name = ev.get("name", "unknown") + dur = float(ev.get("dur", 0)) + kernel_stats[name]["total_us"] += dur + kernel_stats[name]["count"] += 1 + kernel_stats[name]["max_us"] = max(kernel_stats[name]["max_us"], dur) + + total_cuda = sum(v["total_us"] for v in kernel_stats.values()) + + # Group by category + cat_stats = defaultdict(lambda: {"total_us": 0, "count": 0}) + for name, info in kernel_stats.items(): + cat = categorize_kernel(name) + cat_stats[cat]["total_us"] += info["total_us"] + cat_stats[cat]["count"] += info["count"] + + # Fill/zero_ analysis: find FillFunctor kernels and group by tensor size + fill_by_size = defaultdict(lambda: {"total_us": 0, "count": 0}) + for ev in cuda_events: + name = ev.get("name", "") + if "FillFunctor" not in name and "fill" not in name.lower(): + continue + # Extract Input Dims from args if present + args = ev.get("args", {}) + input_dims = args.get("Input Dims", args.get("input_dims", "unknown")) + if isinstance(input_dims, list): + input_dims = str(input_dims) + dur = float(ev.get("dur", 0)) + fill_by_size[input_dims]["total_us"] += dur + fill_by_size[input_dims]["count"] += 1 + + # CPU op analysis for wall-clock estimation (apply same warmup cutoff) + cpu_ops = [ + ev + for ev in events + if ev.get("ph") == "X" and ev.get("cat") in ("cpu_op", "python_function") + ] + if cutoff_ts is not None: + cpu_ops = [ev for ev in cpu_ops if float(ev.get("ts", 0)) > cutoff_ts] + wall_clock_us = 0 + if cpu_ops: + ts_sorted = sorted(cpu_ops, key=lambda e: float(e.get("ts", 0))) + min_cpu_ts = float(ts_sorted[0].get("ts", 0)) + max_cpu_end = max( + float(e.get("ts", 0)) + float(e.get("dur", 0)) for e in ts_sorted + ) + wall_clock_us = max_cpu_end - min_cpu_ts + + return { + "total_cuda_us": total_cuda, + "n_steps": n_steps_profiled, + "categories": dict(cat_stats), + "kernel_stats": dict(kernel_stats), + "n_events": len(cuda_events), + "fill_by_size": dict(fill_by_size), + "wall_clock_us": wall_clock_us, + } + + +def print_trace_analysis(result, label=""): + total = result["total_cuda_us"] + n = result["n_steps"] + + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + + print( + f"\n CUDA kernel time: {total / 1e6:.2f}s over {n} steps " + f"(~{total / n / 1e6:.2f}s/step)" + ) + + if result.get("wall_clock_us"): + wc = result["wall_clock_us"] + print( + f" Wall clock span: {wc / 1e6:.2f}s over {n} steps " + f"(~{wc / n / 1e6:.2f}s/step)" + ) + + print(f"\n {'Category':<40} {'Total':>9} {'%':>6} {'Count':>7} {'Per step':>9}") + print(f" {'-' * 75}") + for cat, info in sorted( + result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True + ): + pct = info["total_us"] / total * 100 + ps = info["total_us"] / n / 1000 + print( + f" {cat:<40} {info['total_us'] / 1000:>8.1f}ms {pct:>5.1f}% " + f"{info['count']:>7} {ps:>7.1f}ms" + ) + + print("\n Top 15 individual kernels:") + for name, info in sorted( + result["kernel_stats"].items(), key=lambda x: x[1]["total_us"], reverse=True + )[:15]: + pct = info["total_us"] / total * 100 + avg = info["total_us"] / info["count"] / 
1000 + print( + f" {name[:62]:<62} {info['total_us'] / 1000:>7.1f}ms " + f"({pct:>4.1f}%) x{info['count']:<5} avg={avg:.3f}ms" + ) + + # Fill/zero_ breakdown + fill_data = result.get("fill_by_size", {}) + if fill_data: + total_fill_us = sum(v["total_us"] for v in fill_data.values()) + if total_fill_us > 0: + print( + f"\n Fill/zero_ kernel breakdown by tensor size " + f"(total: {total_fill_us / 1000:.1f}ms, " + f"{total_fill_us / total * 100:.1f}% of CUDA time):" + ) + print(f" {'Input Dims':<50} {'Time':>9} {'Count':>7}") + print(f" {'-' * 70}") + for dims, info in sorted( + fill_data.items(), key=lambda x: x[1]["total_us"], reverse=True + )[:10]: + dims_str = str(dims)[:50] + print( + f" {dims_str:<50} {info['total_us'] / 1000:>7.1f}ms " + f"{info['count']:>7}" + ) + + +def print_summary(result, mem_result=None): + """Print optimization recommendations and summary.""" + total = result["total_cuda_us"] + n = result["n_steps"] + + print(f"\n{'=' * 75}") + print(" SUMMARY & RECOMMENDATIONS") + print(f"{'=' * 75}") + + # Estimated per-step wall clock + wc = result.get("wall_clock_us", 0) + if wc > 0: + print(f"\n Estimated per-step wall clock: {wc / n / 1e6:.2f}s") + else: + print(f"\n Estimated per-step CUDA time: {total / n / 1e6:.2f}s") + + # Memory utilization + if mem_result: + reserved = mem_result["total_reserved"] + allocated = mem_result["total_allocated"] + if reserved > 0: + util_pct = allocated / reserved * 100 + print( + f" Memory utilization: {util_pct:.1f}% " + f"({allocated / 1e9:.2f} / {reserved / 1e9:.2f} GB)" + ) + + # Build recommendations + recommendations = [] + + cats_sorted = sorted( + result["categories"].items(), key=lambda x: x[1]["total_us"], reverse=True + ) + + # Check top category + if cats_sorted: + top_cat, top_info = cats_sorted[0] + top_pct = top_info["total_us"] / total * 100 + if top_pct > 30: + if "GEMM" in top_cat or "CUTLASS" in top_cat: + recommendations.append( + f"GEMM/matmul dominates ({top_pct:.0f}%). " + f"Consider FP8 training, LoRA (fewer params), " + f"or smaller batch size to reduce compute." + ) + elif "Attention" in top_cat: + recommendations.append( + f"Attention dominates ({top_pct:.0f}%). " + f"Ensure FlashAttention v2/v3 is active. " + f"Consider reducing sequence length or using sliding window." + ) + elif "Fill" in top_cat or "Elementwise" in top_cat: + recommendations.append( + f"Elementwise/Fill ops dominate ({top_pct:.0f}%). " + f"Enable kernel fusion (Liger, torch.compile) to reduce " + f"memory-bound elementwise operations." + ) + elif "ScatterMoE" in top_cat: + recommendations.append( + f"MoE routing/scatter dominates ({top_pct:.0f}%). " + f"Check expert count and capacity factor. " + f"Verify ScatterMoE kernels are using optimal block sizes." + ) + else: + recommendations.append( + f"'{top_cat}' dominates ({top_pct:.0f}%). " + f"Focus optimization efforts here first." + ) + + # Check memory ops + for cat, info in cats_sorted: + pct = info["total_us"] / total * 100 + if "Memory" in cat and pct > 10: + recommendations.append( + f"Memory ops are {pct:.0f}% of CUDA time. " + f"Consider gradient checkpointing, reducing activation " + f"recomputation, or pinned memory for data loading." + ) + break + + # Check fill overhead + fill_data = result.get("fill_by_size", {}) + total_fill_us = sum(v["total_us"] for v in fill_data.values()) + fill_pct = total_fill_us / total * 100 if total > 0 else 0 + if fill_pct > 5: + recommendations.append( + f"Fill/zero_ kernels consume {fill_pct:.1f}% of CUDA time. 
" + f"Large zero-fills suggest excessive tensor allocation. " + f"Consider reusing buffers or lazy initialization." + ) + + # Check fragmentation + if mem_result and mem_result.get("fragmentation_pct", 0) > 20: + frag = mem_result["fragmentation_pct"] + recommendations.append( + f"Memory fragmentation is {frag:.0f}%. " + f"Use PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True " + f"or max_split_size_mb to reduce fragmentation." + ) + + # GPU utilization hint from wall clock vs CUDA time + if wc > 0 and total > 0: + gpu_util = total / wc * 100 + if gpu_util < 50: + recommendations.append( + f"GPU utilization is ~{gpu_util:.0f}% (CUDA time vs wall clock). " + f"CPU-side bottleneck likely. Profile data loading, " + f"reward computation, or weight sync overhead." + ) + + # Print top 3 + print("\n Top optimization recommendations:") + if not recommendations: + print(" No major issues detected.") + for i, rec in enumerate(recommendations[:3], 1): + print(f" {i}. {rec}") + + print() + + +def compare_traces(before, after): + print(f"\n{'=' * 75}") + print(" COMPARISON") + print(f"{'=' * 75}") + + tb = before["total_cuda_us"] + ta = after["total_cuda_us"] + nb = before["n_steps"] + na = after["n_steps"] + + print( + f"\n Per-step CUDA time: {tb / nb / 1e6:.2f}s -> {ta / na / 1e6:.2f}s " + f"({(ta / na - tb / nb) / (tb / nb) * 100:+.1f}%)" + ) + + all_cats = sorted( + set(list(before["categories"]) + list(after["categories"])), + key=lambda c: before["categories"].get(c, {"total_us": 0})["total_us"], + reverse=True, + ) + + print( + f"\n {'Category':<35} {'Before/step':>11} {'After/step':>11} " + f"{'Delta':>9} {'Speedup':>8}" + ) + print(f" {'-' * 78}") + for cat in all_cats: + b = before["categories"].get(cat, {"total_us": 0})["total_us"] / nb + a = after["categories"].get(cat, {"total_us": 0})["total_us"] / na + delta = a - b + speedup = b / a if a > 0 else float("inf") + if b > tb / nb * 0.003 or a > ta / na * 0.003: + print( + f" {cat:<35} {b / 1000:>9.1f}ms {a / 1000:>9.1f}ms " + f"{delta / 1000:>+8.1f}ms {speedup:>7.2f}x" + ) + + +# ---- Allocation churn attribution ----------------------------------------- + + +def _short_path(p): + """Shorten a file path for display.""" + if "/site-packages/" in p: + return "..." + p.split("/site-packages/")[1] + if "/axolotl/src/" in p: + return "axolotl/" + p.split("/axolotl/src/")[1] + if "/axolotl/" in p: + return "axolotl/" + p.split("/axolotl/")[1] + if len(p) > 60: + parts = p.split("/") + return ".../" + "/".join(parts[-3:]) + return p + + +def _get_top_python_frame(frames): + """Get the first meaningful Python frame from a trace frame list. + + Skips internal torch/cuda/autograd frames to find the user-level code + that triggered the allocation. Falls back to first Python frame. 
+ """ + skip_patterns = [ + "torch/autograd/", + "torch/utils/checkpoint", + "torch/nn/modules/module.py", + "torch/_dynamo/", + "torch/_compile", + "torch/_ops", + "torch/_functorch/", + "torch/_inductor/", + "torch/library", + "torch/cuda", + "torch/_C", + "", + "fire/core.py", + "runpy", + ] + py_frames = [ + f + for f in frames + if isinstance(f, dict) and f.get("filename", "").endswith(".py") + ] + for fr in py_frames: + fname = fr.get("filename", "") + if not any(p in fname for p in skip_patterns): + return fr + return py_frames[0] if py_frames else None + + +def _categorize_source(fname, funcname): + """Assign a human-readable category to an allocation source.""" + if "checkpoint" in fname.lower() or "backward" in funcname: + return "Gradient checkpoint recompute" + if "bitsandbytes" in fname: + return "BnB dequantization" + if "scattermoe" in fname or "scatter" in funcname.lower(): + return "ScatterMoE LoRA" + if "adam" in funcname.lower() or "optim" in fname.lower() or "inductor" in fname: + return "Optimizer" + if "norm" in funcname.lower(): + return "LayerNorm" + if "fla/" in fname or "gated_delta" in fname: + return "FLA linear attention" + return "Other" + + +def analyze_allocation_churn(snapshot): + """Group allocation churn by Python source attribution. + + For each major churn size, identifies which Python code creates + the tensors (gradient checkpointing, BnB dequant, LoRA conversion, etc). + """ + traces = snapshot.get("device_traces", [[]])[0] + if not traces: + return None + + # Find the top churn sizes + size_counts = defaultdict(int) + for ev in traces: + if ev.get("action") == "alloc": + size_counts[ev["size"]] += 1 + + # Top 5 sizes by total churn (count × size) + top_sizes = sorted(size_counts.items(), key=lambda x: x[0] * x[1], reverse=True)[:5] + + results = {} + for target_sz, total_count in top_sizes: + mb = target_sz / 1e6 + if mb < 1.0: + continue + + frame_groups = defaultdict(lambda: {"count": 0}) + for ev in traces: + if ev.get("action") == "alloc" and ev.get("size") == target_sz: + top = _get_top_python_frame(ev.get("frames", [])) + if top: + key = (top["filename"], top["name"], top.get("line", 0)) + else: + key = ("", "", 0) + frame_groups[key]["count"] += 1 + + # Group into categories + categories = defaultdict(lambda: {"count": 0, "sources": []}) + for (fname, funcname, lineno), info in frame_groups.items(): + cat = _categorize_source(fname, funcname) + categories[cat]["count"] += info["count"] + categories[cat]["sources"].append( + (funcname, _short_path(fname), lineno, info["count"]) + ) + + results[target_sz] = { + "total_count": total_count, + "total_churn_gb": total_count * target_sz / 1e9, + "categories": dict(categories), + } + + return results + + +def print_allocation_churn(results, label=""): + if not results: + return + + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + else: + print("\n ALLOCATION CHURN BY SOURCE") + + for target_sz in sorted( + results, key=lambda s: results[s]["total_churn_gb"], reverse=True + ): + info = results[target_sz] + mb = target_sz / 1e6 + print( + f"\n {mb:.1f}MB × {info['total_count']:,} = {info['total_churn_gb']:.0f}GB churn:" + ) + for cat, cinfo in sorted( + info["categories"].items(), key=lambda x: x[1]["count"], reverse=True + ): + pct = cinfo["count"] / info["total_count"] * 100 + print(f" {pct:>4.0f}% {cat} ({cinfo['count']} allocs)") + for funcname, short, lineno, cnt in sorted( + cinfo["sources"], key=lambda x: x[3], reverse=True + )[:3]: + print(f" {funcname:35s} 
{short}:{lineno} (x{cnt})") + + +# ---- CPU overhead analysis ------------------------------------------------ + + +def analyze_cpu_overhead(events): + """Analyze memcpy, checkpoint recomputation, and GPU utilization from trace events.""" + cuda_total_us = 0 + cuda_count = 0 + memcpy_stats = defaultdict(lambda: {"dur_us": 0, "count": 0, "bytes": 0}) + checkpoint_us = 0 + checkpoint_count = 0 + min_ts = float("inf") + max_ts_end = 0 + + for ev in events: + if ev.get("ph") != "X": + continue + cat = ev.get("cat", "") + name = ev.get("name", "") + dur = float(ev.get("dur", 0)) + ts = float(ev.get("ts", 0)) + end = ts + dur + if ts < min_ts: + min_ts = ts + if end > max_ts_end: + max_ts_end = end + + if cat == "kernel": + cuda_total_us += dur + cuda_count += 1 + elif cat == "gpu_memcpy": + if "DtoH" in name: + direction = "GPU→CPU (offload)" + elif "HtoD" in name: + direction = "CPU→GPU (reload)" + elif "DtoD" in name: + direction = "GPU→GPU" + else: + direction = name + memcpy_stats[direction]["dur_us"] += dur + memcpy_stats[direction]["count"] += 1 + nbytes = ev.get("args", {}).get("Bytes", 0) + if nbytes: + memcpy_stats[direction]["bytes"] += int(nbytes) + elif cat in ("cpu_op", "python_function"): + nl = name.lower() + if "checkpoint" in nl or "recompute" in nl: + checkpoint_us += dur + checkpoint_count += 1 + + wall_us = max_ts_end - min_ts if max_ts_end > min_ts else 0 + + return { + "wall_clock_us": wall_us, + "cuda_total_us": cuda_total_us, + "cuda_kernel_count": cuda_count, + "memcpy_stats": dict(memcpy_stats), + "checkpoint_us": checkpoint_us, + "checkpoint_count": checkpoint_count, + } + + +def print_cpu_overhead(result, n_steps=2, label=""): + if not result: + return + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + else: + print("\n CPU OVERHEAD ANALYSIS") + + wall = result["wall_clock_us"] + cuda = result["cuda_total_us"] + count = result["cuda_kernel_count"] + gpu_util = cuda / wall * 100 if wall > 0 else 0 + cpu_gap = wall - cuda + + print( + f"\n Wall clock: {wall / 1e6:.2f}s (~{wall / n_steps / 1e6:.2f}s/step)" + ) + print(f" CUDA kernel time: {cuda / 1e6:.2f}s ({count:,} kernels)") + print(f" GPU utilization: {gpu_util:.1f}%") + print(f" CPU overhead: {cpu_gap / 1e6:.2f}s ({100 - gpu_util:.1f}%)") + + memcpy = result["memcpy_stats"] + if memcpy: + total_memcpy = sum(v["dur_us"] for v in memcpy.values()) + total_bytes = sum(v["bytes"] for v in memcpy.values()) + print( + f"\n Memory transfers: {total_memcpy / 1e6:.3f}s " + f"({total_bytes / 1e9:.2f}GB, {total_memcpy / wall * 100:.1f}% of wall)" + ) + for direction, info in sorted( + memcpy.items(), key=lambda x: x[1]["dur_us"], reverse=True + ): + gb = info["bytes"] / 1e9 + print( + f" {direction:30s} {info['dur_us'] / 1e6:.3f}s " + f"x{info['count']:>5} {gb:.2f}GB" + ) + + if result["checkpoint_count"] > 0: + print( + f"\n Gradient checkpoint CPU ops: {result['checkpoint_us'] / 1e6:.3f}s " + f"(x{result['checkpoint_count']})" + ) + + +# ---- Memory snapshot analysis --------------------------------------------- + + +def load_snapshot(path): + """Load a PyTorch CUDA memory snapshot from a pickle file. + + WARNING: This uses pickle.load() which can execute arbitrary code. + Only load snapshot files that you generated yourself from trusted + training runs. Never load snapshots from untrusted sources. 
+ """ + snap_file = Path(path) / "snapshot.pickle" if Path(path).is_dir() else Path(path) + if not snap_file.exists(): + return None + print(f"Loading {snap_file.name} ({snap_file.stat().st_size / 1e6:.0f} MB)...") + with open(snap_file, "rb") as f: + return pickle.load(f) # nosec B301 + + +def _extract_python_frames(snapshot): + """Extract Python source attribution from snapshot blocks with stacks='all'. + + The snapshot structure (when stacks='all') stores frames in: + segments[i].blocks[j].history[k].frames = [(filename, lineno, name), ...] + + Returns a dict mapping (filename, function_name) -> {"bytes": int, "count": int} + """ + source_allocs = defaultdict(lambda: {"bytes": 0, "count": 0}) + + for seg in snapshot.get("segments", []): + for block in seg.get("blocks", []): + if block.get("state") != "active_allocated": + continue + size = block.get("size", 0) + history = block.get("history", []) + if not history: + continue + + # Use the most recent allocation history entry + last_hist = history[-1] + frames = last_hist.get("frames", []) + + # Find the first Python frame (skip C++ frames) + # Frames are tuples: (filename, lineno, name) + attributed = False + for frame in frames: + if not isinstance(frame, (list, tuple)) or len(frame) < 3: + continue + filename, lineno, funcname = frame[0], frame[1], frame[2] + # Skip internal torch/cuda frames to find user-level attribution + fname_str = str(filename) + if any( + skip in fname_str + for skip in [ + "torch/cuda", + "torch/_C", + "torch/utils", + "cuda/memory.py", + "", + ] + ): + continue + key = (fname_str, funcname, lineno) + source_allocs[key]["bytes"] += size + source_allocs[key]["count"] += 1 + attributed = True + break + + # If no user frame found, use first available frame + if not attributed and frames: + frame = frames[0] + if isinstance(frame, (list, tuple)) and len(frame) >= 3: + key = (str(frame[0]), str(frame[2]), frame[1]) + source_allocs[key]["bytes"] += size + source_allocs[key]["count"] += 1 + + return dict(source_allocs) + + +def _extract_source_file_summary(source_allocs): + """Aggregate per-frame allocations to per-file level.""" + file_allocs = defaultdict(lambda: {"bytes": 0, "count": 0, "functions": set()}) + for (filename, funcname, _lineno), info in source_allocs.items(): + file_allocs[filename]["bytes"] += info["bytes"] + file_allocs[filename]["count"] += info["count"] + file_allocs[filename]["functions"].add(funcname) + return dict(file_allocs) + + +def analyze_snapshot(snapshot): + segments = snapshot.get("segments", []) + total_reserved = sum(s.get("total_size", 0) for s in segments) + total_allocated = sum(s.get("allocated_size", 0) for s in segments) + + # Active blocks + active_blocks = [] + for seg in segments: + for block in seg.get("blocks", []): + if block.get("state") == "active_allocated": + active_blocks.append(block.get("size", 0)) + + # Allocation churn from trace + trace = snapshot.get("device_traces", [[]])[0] + size_counts = defaultdict(lambda: {"count": 0, "total": 0}) + for ev in trace: + if ev.get("action") == "alloc": + sz = ev.get("size", 0) + size_counts[sz]["count"] += 1 + size_counts[sz]["total"] += sz + + # Python frame attribution + source_allocs = _extract_python_frames(snapshot) + file_summary = _extract_source_file_summary(source_allocs) + + return { + "total_reserved": total_reserved, + "total_allocated": total_allocated, + "fragmentation_pct": (total_reserved - total_allocated) / total_reserved * 100 + if total_reserved > 0 + else 0, + "n_segments": len(segments), + 
"n_active_blocks": len(active_blocks), + "active_bytes": sum(active_blocks), + "largest_active": sorted(active_blocks, reverse=True)[:10], + "alloc_churn": dict( + sorted(size_counts.items(), key=lambda x: x[1]["total"], reverse=True)[:15] + ), + "n_trace_events": len(trace), + "source_allocs": source_allocs, + "file_summary": file_summary, + } + + +def print_memory_analysis(result, label=""): + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + + reserved = result["total_reserved"] + allocated = result["total_allocated"] + + print(f"\n Reserved: {reserved / 1e9:.2f} GB") + print(f" Allocated: {allocated / 1e9:.2f} GB") + print(f" Utilization: {allocated / reserved * 100:.1f}%" if reserved > 0 else "") + print(f" Fragmentation: {result['fragmentation_pct']:.1f}%") + print( + f" Segments: {result['n_segments']}, Active blocks: {result['n_active_blocks']}" + ) + + print("\n Largest active allocations:") + for sz in result["largest_active"]: + print(f" {sz / 1e6:>10.1f} MB") + + # Python source file attribution + file_summary = result.get("file_summary", {}) + if file_summary: + print("\n Top allocations by source file:") + print(f" {'Source file':<55} {'Alloc':>10} {'Count':>7}") + print(f" {'-' * 75}") + for fname, info in sorted( + file_summary.items(), key=lambda x: x[1]["bytes"], reverse=True + )[:15]: + # Shorten path for display + short = fname + if len(short) > 55: + parts = short.split("/") + # Keep last 3 path components + short = ".../" + "/".join(parts[-3:]) + if len(short) > 55: + short = short[:52] + "..." + funcs = ", ".join(sorted(info["functions"])[:3]) + sz = info["bytes"] + unit = "MB" + val = sz / 1e6 + if val >= 1000: + unit = "GB" + val = sz / 1e9 + print(f" {short:<55} {val:>8.1f}{unit} {info['count']:>7}") + if funcs: + print(f" functions: {funcs[:70]}") + + # Top allocations by function (more granular) + source_allocs = result.get("source_allocs", {}) + if source_allocs: + print("\n Top allocations by function (with line numbers):") + print(f" {'Function':<35} {'File:Line':<35} {'Size':>10}") + print(f" {'-' * 82}") + for (fname, funcname, lineno), info in sorted( + source_allocs.items(), key=lambda x: x[1]["bytes"], reverse=True + )[:15]: + # Shorten filename + short_file = fname + parts = short_file.split("/") + if len(parts) > 2: + short_file = "/".join(parts[-2:]) + loc = f"{short_file}:{lineno}" + if len(loc) > 35: + loc = "..." + loc[-32:] + sz = info["bytes"] + if sz >= 1e9: + sz_str = f"{sz / 1e9:.2f}GB" + else: + sz_str = f"{sz / 1e6:.1f}MB" + print(f" {funcname:<35} {loc:<35} {sz_str:>10}") + + print("\n Top allocation churn (alloc count x size):") + print(f" {'Size':>12} {'Count':>8} {'Total churned':>14}") + print(f" {'-' * 38}") + for sz, info in result["alloc_churn"].items(): + if sz >= 1e6: + print( + f" {sz / 1e6:>10.1f}MB {info['count']:>8} " + f"{info['total'] / 1e9:>12.2f}GB" + ) + + +# ---- Peak memory timeline from trace events -------------------------------- + + +def analyze_peak_memory(snapshot): + """Walk through device_traces chronologically to find peak concurrent memory usage. + + The snapshot's segment data only captures end-of-step state. The device_traces + record every alloc/free, letting us reconstruct peak usage and identify which + allocation sources were live at that moment. 
+ """ + traces = snapshot.get("device_traces", [[]])[0] + if not traces: + return None + + current = 0 + peak = 0 + peak_idx = 0 + live_allocs = {} # addr -> (size, frames) + peak_live = {} + + for i, ev in enumerate(traces): + action = ev.get("action") + addr = ev.get("addr", 0) + size = ev.get("size", 0) + + if action == "alloc": + current += size + live_allocs[addr] = (size, ev.get("frames", [])) + if current > peak: + peak = current + peak_idx = i + peak_live = dict(live_allocs) + elif action == "free_requested": + if addr in live_allocs: + current -= live_allocs[addr][0] + del live_allocs[addr] + + # Categorize allocations at peak + peak_categories = defaultdict(lambda: {"bytes": 0, "count": 0}) + for _addr, (size, frames) in peak_live.items(): + top = _get_top_python_frame(frames) + if top: + cat = _categorize_source(top["filename"], top["name"]) + else: + cat = "Unknown" + peak_categories[cat]["bytes"] += size + peak_categories[cat]["count"] += 1 + + return { + "peak_bytes": peak, + "peak_event_idx": peak_idx, + "total_events": len(traces), + "end_bytes": current, + "peak_categories": dict(peak_categories), + } + + +def print_peak_memory(result, mem_result=None, label=""): + if not result: + return + + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + + peak_gb = result["peak_bytes"] / 1e9 + end_gb = result["end_bytes"] / 1e9 + # The device_traces only record allocations AFTER profiling starts. + # Model weights and other persistent allocations are not tracked. + # We can estimate the persistent baseline from snapshot allocated - peak_traced. + persistent_gb = 0 + if mem_result: + persistent_gb = mem_result["total_allocated"] / 1e9 - end_gb + total_peak_gb = persistent_gb + peak_gb + + print( + f"\n Profiled peak (transient): {peak_gb:.2f} GB " + f"(at event {result['peak_event_idx']:,} / {result['total_events']:,})" + ) + if persistent_gb > 0: + print( + f" Persistent baseline: {persistent_gb:.2f} GB " + f"(model + optimizer, allocated before profiling)" + ) + print(f" Estimated total peak: {total_peak_gb:.2f} GB") + print(f" Transient headroom: {peak_gb - end_gb:.2f} GB above end-of-trace") + + cats = result.get("peak_categories", {}) + if cats: + print("\n Allocations live at peak:") + print(f" {'Category':<35} {'Size':>10} {'Count':>7}") + print(f" {'-' * 55}") + for cat, info in sorted( + cats.items(), key=lambda x: x[1]["bytes"], reverse=True + ): + sz = info["bytes"] + if sz >= 1e9: + sz_str = f"{sz / 1e9:.2f} GB" + else: + sz_str = f"{sz / 1e6:.1f} MB" + print(f" {cat:<35} {sz_str:>10} {info['count']:>7}") + + +# ---- Fragmentation diagnosis ----------------------------------------------- + + +def analyze_fragmentation(snapshot): + """Analyze segment-level memory layout to explain fragmentation. + + Examines each CUDA segment for inactive (freed but unreturned) blocks, + pinned small allocations that prevent segment merging, and the overall + segment size distribution. 
+ """ + segments = snapshot.get("segments", []) + if not segments: + return None + + total_reserved = 0 + total_allocated = 0 + total_inactive = 0 + segment_sizes = [] + inactive_gaps = [] # (gap_size, segment_size, active_around) + pinned_fragments = [] # small active blocks surrounded by inactive + + for seg in segments: + seg_size = seg.get("total_size", 0) + total_reserved += seg_size + segment_sizes.append(seg_size) + blocks = seg.get("blocks", []) + + seg_active = 0 + seg_inactive = 0 + for bi, block in enumerate(blocks): + bsize = block.get("size", 0) + if block.get("state") == "active_allocated": + seg_active += bsize + total_allocated += bsize + # Check if this small block is surrounded by inactive + if bsize < 2 * 1024 * 1024: # < 2MB + prev_inactive = bi > 0 and blocks[bi - 1].get("state") == "inactive" + next_inactive = ( + bi < len(blocks) - 1 + and blocks[bi + 1].get("state") == "inactive" + ) + if prev_inactive and next_inactive: + pinned_fragments.append((bsize, seg_size)) + elif block.get("state") == "inactive": + seg_inactive += bsize + total_inactive += bsize + inactive_gaps.append((bsize, seg_size)) + + # Classify segment sizes + size_buckets = defaultdict(lambda: {"count": 0, "total": 0}) + for sz in segment_sizes: + if sz >= 1024 * 1024 * 1024: + bucket = ">=1 GB" + elif sz >= 256 * 1024 * 1024: + bucket = "256MB-1GB" + elif sz >= 64 * 1024 * 1024: + bucket = "64-256MB" + elif sz >= 2 * 1024 * 1024: + bucket = "2-64MB" + else: + bucket = "<2MB" + size_buckets[bucket]["count"] += 1 + size_buckets[bucket]["total"] += sz + + # Large inactive gaps that could be reclaimed + inactive_gaps.sort(key=lambda x: x[0], reverse=True) + + return { + "total_reserved": total_reserved, + "total_allocated": total_allocated, + "total_inactive": total_inactive, + "n_segments": len(segments), + "segment_size_buckets": dict(size_buckets), + "large_inactive_gaps": inactive_gaps[:20], + "pinned_fragments": len(pinned_fragments), + "expandable_segments_would_help": ( + total_inactive > 0.1 * total_reserved and len(segments) > 10 + ), + } + + +def print_fragmentation(result, gpu_capacity_gb=None, label=""): + if not result: + return + + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + + reserved = result["total_reserved"] + allocated = result["total_allocated"] + inactive = result["total_inactive"] + frag_pct = inactive / reserved * 100 if reserved > 0 else 0 + + print( + f"\n Reserved: {reserved / 1e9:.2f} GB across {result['n_segments']} segments" + ) + print(f" Allocated: {allocated / 1e9:.2f} GB") + print(f" Inactive: {inactive / 1e9:.2f} GB ({frag_pct:.1f}% fragmentation)") + + if result["pinned_fragments"] > 0: + print( + f" Pinned small blocks (<2MB between inactive): " + f"{result['pinned_fragments']} (prevent segment merging)" + ) + + # Segment size distribution + print("\n Segment size distribution:") + bucket_order = [">=1 GB", "256MB-1GB", "64-256MB", "2-64MB", "<2MB"] + for bucket in bucket_order: + info = result["segment_size_buckets"].get(bucket) + if info: + print( + f" {bucket:<12} {info['count']:>4} segments " + f"{info['total'] / 1e9:>6.2f} GB" + ) + + # Largest inactive gaps + gaps = result.get("large_inactive_gaps", []) + if gaps: + print("\n Largest inactive gaps (freed but unreclaimable):") + shown = 0 + for gap_sz, seg_sz in gaps: + if gap_sz >= 32 * 1024 * 1024 and shown < 10: + print( + f" {gap_sz / 1e6:>8.0f} MB gap in {seg_sz / 1e6:.0f} MB segment" + ) + shown += 1 + + # OOM risk assessment + if gpu_capacity_gb: + gpu_bytes = 
gpu_capacity_gb * 1e9 + usable = gpu_bytes - (reserved - allocated) # capacity minus fragmented waste + print(f"\n OOM Risk Assessment (GPU: {gpu_capacity_gb:.1f} GB):") + print( + f" Usable capacity: {usable / 1e9:.2f} GB " + f"(GPU capacity minus {inactive / 1e9:.2f} GB fragmentation)" + ) + headroom = gpu_bytes - reserved + print(f" Current headroom: {headroom / 1e9:.2f} GB") + if headroom < 1.0e9: + print(" ⚠ CRITICAL: <1 GB headroom — high OOM risk!") + elif headroom < 2.0e9: + print(" ⚠ WARNING: <2 GB headroom — moderate OOM risk") + + # Recommendation + if result.get("expandable_segments_would_help"): + print("\n → FIX: Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True") + print(" This eliminates segment fragmentation by growing segments in-place,") + print( + f" which would reclaim up to {inactive / 1e9:.1f} GB of wasted memory." + ) + + +# ---- Sequence-length scaling analysis -------------------------------------- + + +def analyze_scaling(mem_a, mem_b, churn_a, churn_b): + """Compare per-tensor allocation sizes between two runs. + + When two profiles differ only in sequence length, this shows which tensor + categories scale with sequence length by comparing the dominant tensor sizes. + Total churn may differ due to different profiling windows, so we focus on + per-tensor size ratios instead. + """ + if not churn_a or not churn_b: + return None + + def _cat_sizes(churn): + """Map category -> {largest_size, total_bytes, count}.""" + cat_data = defaultdict(lambda: {"max_size": 0, "sizes": [], "count": 0}) + for sz, info in churn.items(): + for cat, cinfo in info.get("categories", {}).items(): + cat_data[cat]["count"] += cinfo["count"] + cat_data[cat]["sizes"].append((sz, cinfo["count"])) + if sz > cat_data[cat]["max_size"]: + cat_data[cat]["max_size"] = sz + return dict(cat_data) + + cats_a = _cat_sizes(churn_a) + cats_b = _cat_sizes(churn_b) + + all_cats = sorted( + set(list(cats_a) + list(cats_b)), + key=lambda c: max( + cats_a.get(c, {"max_size": 0})["max_size"], + cats_b.get(c, {"max_size": 0})["max_size"], + ), + reverse=True, + ) + + scaling = [] + for cat in all_cats: + a = cats_a.get(cat) + b = cats_b.get(cat) + if not a or not b: + continue + a_max = a["max_size"] + b_max = b["max_size"] + if a_max > 1e6 and b_max > 1e6: # Only compare >1MB tensors + tensor_ratio = b_max / a_max if a_max > 0 else None + scaling.append( + { + "category": cat, + "size_a_mb": a_max / 1e6, + "size_b_mb": b_max / 1e6, + "tensor_ratio": tensor_ratio, + "count_a": a["count"], + "count_b": b["count"], + "scales_with_seqlen": tensor_ratio is not None + and tensor_ratio > 1.05, + } + ) + + scaling.sort(key=lambda x: x["size_b_mb"], reverse=True) + return scaling + + +def print_scaling(scaling, label_a="Before", label_b="After", label=""): + if not scaling: + return + + if label: + print(f"\n{'=' * 75}") + print(f" {label}") + print(f"{'=' * 75}") + + print("\n Per-tensor size comparison (largest tensor per category):") + print( + f" {'Category':<35} {'A size':>10} {'B size':>10} {'Ratio':>7} {'Scales?':>8}" + ) + print(f" {'-' * 73}") + for entry in scaling: + ratio_str = f"{entry['tensor_ratio']:.2f}x" if entry["tensor_ratio"] else "N/A" + scales = "YES" if entry["scales_with_seqlen"] else "no" + print( + f" {entry['category']:<35} {entry['size_a_mb']:>8.1f}MB " + f"{entry['size_b_mb']:>8.1f}MB {ratio_str:>7} {scales:>8}" + ) + + # Summary + seq_scaling = [e for e in scaling if e["scales_with_seqlen"]] + constant = [e for e in scaling if not e["scales_with_seqlen"]] + if seq_scaling: + ratios = 
[e["tensor_ratio"] for e in seq_scaling if e["tensor_ratio"]] + avg_ratio = sum(ratios) / len(ratios) if ratios else 0 + print(f"\n Sequence-length scaling detected ({avg_ratio:.2f}x avg):") + for e in seq_scaling: + print( + f" - {e['category']}: {e['size_a_mb']:.1f}MB -> " + f"{e['size_b_mb']:.1f}MB ({e['tensor_ratio']:.2f}x)" + ) + if constant: + print("\n Constant-size categories (do not scale with seq len):") + for e in constant: + print(f" - {e['category']}: {e['size_a_mb']:.1f}MB") + + +# ---- Main ----------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze axolotl training profiler output", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "path", + help="Path to output directory (containing profiler_trace.json and/or " + "snapshot.pickle) or directly to a trace file. " + "Security note: snapshot.pickle uses pickle deserialization — " + "only use files from your own trusted training runs.", + ) + parser.add_argument( + "--compare", + help="Path to second run for A/B comparison. " + "Same security note as path: only use trusted snapshot files.", + ) + parser.add_argument( + "--include-warmup", + action="store_true", + help="Include step 0 (warmup/compilation) in timing analysis", + ) + parser.add_argument( + "--memory-only", + action="store_true", + help="Only analyze memory snapshot, skip trace", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Only load first 2M events for rapid analysis of large traces", + ) + parser.add_argument( + "--gpu-gb", + type=float, + default=None, + help="GPU total memory in GB (for OOM risk assessment). " + "Auto-detected if not specified.", + ) + args = parser.parse_args() + + # Auto-detect GPU capacity if not specified + gpu_capacity_gb = args.gpu_gb + if gpu_capacity_gb is None: + try: + import torch + + if torch.cuda.is_available(): + gpu_capacity_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + except Exception: + pass + + skip = not args.include_warmup + trace_result = None + mem_result = None + events = None + + # -- Trace analysis -- + if not args.memory_only: + events = load_trace(args.path, quick=args.quick) + if events: + trace_result = analyze_trace(events, skip_warmup=skip) + if trace_result: + print_trace_analysis(trace_result, label=f"Trace: {args.path}") + + if args.compare: + events2 = load_trace(args.compare, quick=args.quick) + if events2: + result2 = analyze_trace(events2, skip_warmup=skip) + if result2: + print_trace_analysis( + result2, label=f"Trace: {args.compare}" + ) + compare_traces(trace_result, result2) + else: + print(f" No profiler_trace.json found in {args.path}") + + # -- CPU overhead analysis (from trace) -- + if events and trace_result: + cpu_result = analyze_cpu_overhead(events) + print_cpu_overhead( + cpu_result, + n_steps=trace_result["n_steps"], + label=f"CPU Overhead: {args.path}", + ) + + # -- Memory analysis -- + snapshot = load_snapshot(args.path) + churn_result = None + churn2 = None + if snapshot: + mem_result = analyze_snapshot(snapshot) + print_memory_analysis(mem_result, label=f"Memory: {args.path}") + + # Peak memory timeline + peak_result = analyze_peak_memory(snapshot) + print_peak_memory( + peak_result, mem_result=mem_result, label=f"Peak Memory: {args.path}" + ) + + # Fragmentation diagnosis + frag_result = analyze_fragmentation(snapshot) + print_fragmentation( + frag_result, + gpu_capacity_gb=gpu_capacity_gb, + 
label=f"Fragmentation: {args.path}", + ) + + # Allocation churn attribution + churn_result = analyze_allocation_churn(snapshot) + if churn_result: + print_allocation_churn(churn_result, label=f"Allocation Churn: {args.path}") + + if args.compare: + snapshot2 = load_snapshot(args.compare) + if snapshot2: + mem2 = analyze_snapshot(snapshot2) + print_memory_analysis(mem2, label=f"Memory: {args.compare}") + + # Peak memory for comparison + peak2 = analyze_peak_memory(snapshot2) + print_peak_memory( + peak2, mem_result=mem2, label=f"Peak Memory: {args.compare}" + ) + + # Fragmentation for comparison + frag2 = analyze_fragmentation(snapshot2) + print_fragmentation( + frag2, + gpu_capacity_gb=gpu_capacity_gb, + label=f"Fragmentation: {args.compare}", + ) + + churn2 = analyze_allocation_churn(snapshot2) + if churn2: + print_allocation_churn( + churn2, label=f"Allocation Churn: {args.compare}" + ) + + # Memory comparison summary + print("\n Memory comparison:") + print( + f" Reserved: {mem_result['total_reserved'] / 1e9:.2f} -> " + f"{mem2['total_reserved'] / 1e9:.2f} GB" + ) + print( + f" Allocated: {mem_result['total_allocated'] / 1e9:.2f} -> " + f"{mem2['total_allocated'] / 1e9:.2f} GB" + ) + print( + f" Frag: {mem_result['fragmentation_pct']:.1f}% -> " + f"{mem2['fragmentation_pct']:.1f}%" + ) + if peak_result and peak2: + print( + f" Peak: {peak_result['peak_bytes'] / 1e9:.2f} -> " + f"{peak2['peak_bytes'] / 1e9:.2f} GB" + ) + + # Scaling analysis + if churn_result and churn2: + scaling = analyze_scaling(mem_result, mem2, churn_result, churn2) + print_scaling( + scaling, + label_a=str(args.path), + label_b=str(args.compare), + label="Allocation Scaling Analysis", + ) + elif not args.memory_only: + pass # trace-only is fine + else: + print(f" No snapshot.pickle found in {args.path}") + + # -- Summary -- + if trace_result: + print_summary(trace_result, mem_result=mem_result) + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 96183973f..d2f2e414c 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -435,6 +435,23 @@ class AxolotlTrainer( num_items_in_batch=num_items_in_batch, ) + # Gemma4ForConditionalGeneration computes loss with a manual + # nn.CrossEntropyLoss() that bypasses proper num_items_in_batch + # normalization and does redundant attention_mask filtering. + # Compute loss externally using the standard loss_function instead. + if _model_type == "gemma4" and "labels" in inputs: + labels = inputs.pop("labels") + outputs = model(**inputs) + logits = outputs.logits + unwrapped = self.accelerator.unwrap_model(model) + vocab_size = unwrapped.config.get_text_config().vocab_size + loss = unwrapped.loss_function( + logits, labels, vocab_size, num_items_in_batch=num_items_in_batch + ) + if return_outputs: + return loss, outputs + return loss + return super().compute_loss( model, inputs, diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index d56109570..9c4a26351 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, swiglu=cfg.liger_glu_activation, ) + elif cfg.model_config_type in ("gemma4", "gemma4_text"): + # Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for + # gradient checkpointing compatibility, RoPE incompatible (separate q/k). 
+ from liger_kernel.transformers.geglu import LigerGEGLUMLP + from transformers.models.gemma4 import modeling_gemma4 + + if cfg.liger_rms_norm: + _OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm + + class _LigerGemma4RMSNorm(LigerRMSNorm): + """LigerRMSNorm for Gemma4 with in_place=False and with_scale support.""" + + def __new__(cls, dim, eps=1e-6, with_scale=True): + if not with_scale: + return _OrigGemma4RMSNorm(dim, eps, with_scale=False) + return super().__new__(cls) + + def __init__(self, dim, eps=1e-6, with_scale=True): + if not with_scale: + return + # offset=0.0 (standard), in_place=False (gradient checkpointing safe) + super().__init__( + dim, eps, offset=0.0, casting_mode="llama", in_place=False + ) + + modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm + if cfg.liger_glu_activation: + + class _LigerGemma4MLP(LigerGEGLUMLP): + def __init__(self, config, layer_idx=None): + super().__init__(config) + + modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP + if cfg.liger_rope: + LOG.warning( + "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping." + ) + if cfg.liger_layer_norm: + modeling_gemma4.nn.LayerNorm = LigerLayerNorm + if cfg.liger_cross_entropy: + modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + LOG.warning( + "Liger fused linear cross entropy is not compatible with Gemma4. Skipping." + ) + LOG.info( + f"Applied Liger kernels for gemma4: " + f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, " + f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}" + ) elif cfg.liger_fused_linear_cross_entropy: try: from .models.base import patch_lce_forward diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 9cf65286a..b9e7f3e0c 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -112,6 +112,47 @@ QKV_PATCHES = [ else: key_states = key_states.view(hidden_shape) value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states +""".lstrip("\n"), + ), + # Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces + # past_key_values.shared_layers, and v_norm added after k_norm. + ( + """ + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer. 
+ # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache + # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None + if self.is_kv_shared_layer: + key_states, value_states = shared_kv_states[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states +""".lstrip("\n"), + """ + query_states, key_states, value_states = self.apply_qkv(hidden_states) + query_states = query_states.view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer. + # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache + # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None + if self.is_kv_shared_layer: + key_states, value_states = shared_kv_states[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states """.lstrip("\n"), ), ]