---
title: "Training Stability & Debugging"
order: 15
description: "Guide to monitoring, debugging, and stabilizing training runs in axolotl"
---

This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.

## Monitoring Training

### Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:

| Metric | What It Tells You | Healthy Range |
|--------|-------------------|---------------|
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5--2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with a small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1--10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow the expected schedule (warmup, then decay) |

::: {.callout-tip}
## Set Up Logging Early

Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

```yaml
wandb_project: my-project
wandb_run_id:  # optional, for resuming
logging_steps: 1
```
:::

### Key Metrics for RL (GRPO)

GRPO training logs a richer set of metrics. These are the critical ones:

| Metric | Healthy Range | Red Flag |
|--------|---------------|----------|
| `rewards/*/mean` | > 0.15 within 20 steps | Stays at 0 -- reward function is broken or the task is too hard |
| `reward_std` | > 0 on most steps | Always 0 -- no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step -- zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001--1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05--0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0--0.5 | > 2.0 suggests the policy has diverged too far from the reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means the policy has diverged far from the vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; sync weights more often (decrease `vllm_sync_interval`) |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing toward the max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` -- increase it |

::: {.callout-note}
## EBFT-Specific Metrics

For EBFT training, also monitor `ebft/alignment` (should trend upward; healthy 0.3--0.9), `ebft/diversity` (healthy 0.01--0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
:::

## SFT Stability

### Loss Plateau

**Symptom**: Loss stops decreasing early in training, well above expected values.

**Causes and fixes**:

- **Learning rate too low**: Increase by 2--5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- **Insufficient warmup**: Set `warmup_steps` to 5--10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- **Data quality**: Check that labels are correctly masked. Run `axolotl preprocess` and inspect tokenized samples to confirm that only the target tokens are trainable (see the sketch below).
- **Weight decay too high**: The default of 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
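To verify the label-masking point above, you can decode only the trainable tokens of a preprocessed sample. This is a minimal sketch, assuming `dataset_prepared_path: last_run_prepared` and that the prepared data is stored there as a Hugging Face dataset; the exact on-disk layout varies between axolotl versions, so adjust the path, and the tokenizer name is a placeholder:

```python
# Minimal sketch: decode only the trainable tokens of a preprocessed sample.
# Assumptions: dataset_prepared_path is "last_run_prepared", the prepared data
# is a Hugging Face dataset on disk, and -100 is the ignore index for labels.
from pathlib import Path
from datasets import load_from_disk
from transformers import AutoTokenizer

# Pick the most recently written prepared subfolder (layout assumption).
prepared_dir = sorted(Path("last_run_prepared").iterdir(), key=lambda p: p.stat().st_mtime)[-1]
ds = load_from_disk(str(prepared_dir))
tok = AutoTokenizer.from_pretrained("your-base-model")  # placeholder: use your config's base_model

sample = ds[0]
trainable = [tok_id for tok_id, lbl in zip(sample["input_ids"], sample["labels"]) if lbl != -100]
print(f"{len(trainable)} / {len(sample['input_ids'])} tokens contribute to the loss")
print(tok.decode(trainable))  # should contain only the target (assistant) text
```

If the decoded text still includes the prompt or system message, revisit your `train_on_inputs` setting and chat template configuration.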
### Loss Spikes

**Symptom**: Loss suddenly jumps by 2--10x, then (possibly) recovers.

**Causes and fixes**:

- **Bad data samples**: A single malformed or extremely long example can cause a spike. Set `sample_packing: false` temporarily and check whether spikes correlate with specific batches.
- **Learning rate too high**: Reduce by 2--5x, or increase warmup.
- **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus`. Confirm this is what you intend; an unexpectedly large effective batch size with an unadjusted learning rate can destabilize updates.
- **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.

### Overfitting

**Symptom**: Train loss keeps decreasing but eval loss starts increasing.

**Fixes**:

- Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
- Reduce `num_epochs` or `max_steps`.
- Increase `weight_decay` (try 0.01--0.1).
- Use a smaller LoRA rank (`lora_r`). Typical values: 8--32.
- Increase dropout: `lora_dropout: 0.05`.

## RL/GRPO Stability

### Reward Never Increases

If `rewards/*/mean` stays at 0 for more than 20 steps:

1. **Test the reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values.

   ```bash
   cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
   ```

2. **Check dataset columns**: The reward function receives `**kwargs` containing the dataset columns. Verify that the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect the logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.

### Entropy Collapse (Mode Collapse)

**Symptom**: `entropy` drops below 0.01; all completions become nearly identical.

**Fixes**:

- Increase `temperature` in the generation kwargs (try 0.8--1.0).
- Reduce the learning rate.
- Add a KL penalty term (the `beta` parameter in the GRPO config).
- Check that `num_generations` is sufficient (16+ gives better advantage estimates).

### IS Ratio Divergence

**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.

This means the policy has diverged significantly from the weights vLLM used for generation, so the importance sampling correction becomes unreliable.

**Fixes**:

- Decrease `vllm_sync_interval` (sync weights more often).
- Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
- Use `importance_sampling_level: token` for finer-grained correction.

### Gradient Norm Instability

**Symptom**: `grad_norm` oscillates wildly or regularly exceeds 10.0.

**Fixes**:

- Enable gradient clipping: `max_grad_norm: 1.0` (the default in most configs).
- Reduce the learning rate.
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see the next section). For an early warning on spikes, see the callback sketch below.
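To catch spikes as they happen rather than after the run, a small logging callback can flag offending steps. This is a minimal sketch, assuming you can register a `transformers` `TrainerCallback` (for example from a custom script or plugin) and that your `transformers` version includes `grad_norm` in the logged metrics; the class name and threshold are illustrative, not built-in axolotl behavior:

```python
# Minimal sketch: warn when the logged grad_norm exceeds a threshold.
# Assumption: your Trainer logs a "grad_norm" key (recent transformers versions do).
from transformers import TrainerCallback

class GradNormWatcher(TrainerCallback):
    """Print a warning whenever the logged grad_norm exceeds a threshold."""

    def __init__(self, threshold: float = 10.0):
        self.threshold = threshold

    def on_log(self, args, state, control, logs=None, **kwargs):
        grad_norm = (logs or {}).get("grad_norm")
        if grad_norm is not None and grad_norm > self.threshold:
            print(f"[step {state.global_step}] grad_norm={grad_norm:.2f} "
                  f"exceeds threshold {self.threshold}")

# Usage (assuming you construct the Trainer yourself):
# trainer.add_callback(GradNormWatcher(threshold=10.0))
```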
## NaN and Inf Handling

### Common Causes

| Cause | Where It Manifests | Detection |
|-------|-------------------|-----------|
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |

### FP8-Specific NaN Issues

FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.

**Fixes applied in axolotl**:

- The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
- A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
- Embedding padding is zero-padded for FP8 compatibility.

::: {.callout-important}
## After Modifying Triton Kernels

If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:

```bash
rm -rf ~/.triton/cache
```
:::

### General NaN Debugging Steps

1. **Enable anomaly detection** (slow, but pinpoints the source):

   ```python
   torch.autograd.set_detect_anomaly(True)
   ```

2. **Check `grad_norm`**: If it goes to NaN, the backward pass is the problem. If loss is NaN but `grad_norm` was fine on the previous step, the forward pass is the problem.
3. **Reduce to a single GPU and a single batch**: Eliminate distributed-training variables.
4. **Inspect the data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.

## OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:

### Step 1: Reduce Batch Size

This is the single highest-impact change. VRAM scales roughly linearly with batch size.

```yaml
micro_batch_size: 1             # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
```

For GRPO specifically, the logits tensor for policy logprob computation can be very large: `batch_size * num_generations * seq_len * vocab_size` elements in bf16. For example, with `num_generations: 16` and `micro_batch_size: 8`, the logits tensor alone is:

```
8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (far too large)
```

Reduce `micro_batch_size` to 2--4 for GRPO. (A quick estimator sketch for these numbers appears after Step 4.)

### Step 2: Enable Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass instead of storing them.

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false  # Recommended default
```

::: {.callout-warning}
## Reentrant Checkpointing Exceptions

Some configurations require `use_reentrant: true`:

- DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
- EBFT strided mode with flex_attention
:::

### Step 3: Use Quantization

Load the base model in reduced precision:

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```

### Step 4: Reduce Sequence Length

```yaml
sequence_len: 1024  # Down from 2048 or 4096
```

For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.
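To sanity-check these numbers against your own config, the two dominant terms discussed above (the GRPO logits tensor from Step 1 and the quadratic attention-score matrices from Step 4) can be estimated with a few lines of arithmetic. A minimal sketch; the example values mirror the worked example in Step 1, and the head count is an illustrative assumption, not a recommendation:

```python
# Back-of-the-envelope VRAM estimates for two terms discussed in this section.
# Substitute your own config values; the numbers below are only examples.

GIB = 1024 ** 3

def grpo_logits_bytes(micro_batch_size, num_generations, seq_len, vocab_size, bytes_per_el=2):
    """Size of the policy logits tensor (bf16 = 2 bytes per element)."""
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_el

def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_el=2):
    """Per-layer score matrix materialized by standard (non-flash) attention."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_el

# ~74 GiB, consistent with the ~75 GB figure in Step 1
print(f"GRPO logits: {grpo_logits_bytes(8, 16, 2048, 151936) / GIB:.1f} GiB")
# Quadratic in seq_len; 32 heads is an assumed example value
print(f"Attention scores (1 layer, 32 heads): {attention_scores_bytes(8, 32, 2048) / GIB:.2f} GiB")
```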
### Step 5: Use Flash Attention

Reduces attention memory from O(n^2) to O(n):

```yaml
attn_implementation: flash_attention_2
```

### Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to CPU:

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```

### Diagnosing the Specific Culprit

Use the `profiler_steps` config option to capture GPU memory snapshots:

```yaml
profiler_steps: [1, 2]
```

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.

## Common Errors

| Error Message | Likely Cause | Fix |
|---------------|-------------|-----|
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or the number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow the OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](nccl.qmd); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to the config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on a single GPU | Remove the `deepspeed:` section from the config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | `micro_batch_size` too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (automatic bf16-to-fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check that `eos_token` matches the template; file a GitHub issue if it persists |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |

## Profiling

### PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

```yaml
profiler_steps: [1, 2, 3]
```

This captures profiler traces for the specified steps. View them in TensorBoard:

```bash
tensorboard --logdir output_dir/runs
```

Or open the `.json` trace file in `chrome://tracing`.

### CUDA Memory Snapshots

For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:

```python
import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```

Visualize the snapshot by dragging the pickle file into PyTorch's web-based viewer at <https://pytorch.org/memory_viz>, or render it locally:

```bash
python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o memory_snapshot.html
```

### Quick GPU Memory Check

During training, monitor GPU utilization in a separate terminal:

```bash
watch -n 1 nvidia-smi
```

For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`.
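If you want the same numbers from your own debugging script, for example around a suspect training step, a minimal sketch using those PyTorch counters:

```python
import torch

GIB = 1024 ** 3

# Peak memory since the process started (or since the last reset)
print(f"max allocated: {torch.cuda.max_memory_allocated() / GIB:.2f} GiB")
print(f"max reserved:  {torch.cuda.max_memory_reserved() / GIB:.2f} GiB")

# Reset the peaks so the next step can be measured in isolation
torch.cuda.reset_peak_memory_stats()
```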
Note that these report PyTorch's view of memory, which may differ from `nvidia-smi` (see [FAQ](faq.qmd)).

## W&B and Logging

### Enabling Logging

```yaml
wandb_project: my-project
wandb_entity: my-team        # optional
wandb_run_id: run-123        # optional, for resuming
wandb_name: experiment-name  # optional
logging_steps: 1             # log every step (recommended for RL)
```

### Debug Logging

For detailed axolotl-internal debug output:

```bash
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```

::: {.callout-tip}
## Always Log to a File

Pipe training output to a log file so you can inspect it after the run:

```bash
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```
:::

### What Axolotl Logs

**SFT metrics** (logged every `logging_steps`):

- `train/loss`, `eval/loss` -- training and validation loss
- `train/grad_norm` -- gradient L2 norm (before clipping)
- `train/learning_rate` -- current learning rate
- `memory/max_alloc`, `memory/max_reserved` -- peak GPU memory

**GRPO/RL metrics** (logged every step):

- `rewards/*/mean`, `rewards/*/std` -- per-reward-function statistics
- `reward`, `reward_std` -- aggregated reward across all reward functions
- `frac_reward_zero_std` -- fraction of prompt groups where all completions got the same reward
- `completions/mean_length`, `completions/min_length`, `completions/max_length` -- completion token lengths
- `completions/clipped_ratio` -- fraction of completions that hit the max length
- `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` -- lengths of naturally terminated completions
- `kl` -- KL divergence between policy and reference
- `entropy` -- policy entropy (a measure of output diversity)
- `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` -- PPO clipping statistics
- `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` -- log-probability difference between the policy and the sampling distribution
- `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` -- IS ratio statistics for off-policy correction
- `num_tokens` -- total tokens processed

### Reading W&B Charts

For a healthy GRPO run, expect to see:

1. **`reward`**: A gradual upward trend. It may start near 0 and reach 0.3--0.8 depending on task difficulty. It is not monotonic -- fluctuations are normal.
2. **`entropy`**: A gradual decrease from initial values (often 0.3--0.6) as the model becomes more confident. It should not collapse to near zero.
3. **`grad_norm`**: Mostly in the 0.001--1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task's natural answer length. If it steadily increases toward `max_completion_length`, the model may be reward-hacking by generating longer outputs.