---
title: "Training Stability & Debugging"
order: 15
description: "Guide to monitoring, debugging, and stabilizing training runs in axolotl"
---

This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.

## Monitoring Training

### Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:

| Metric | What It Tells You | Healthy Range |
|--------|-------------------|---------------|
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5--2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1--10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow expected schedule (warmup then decay) |

::: {.callout-tip}
## Set Up Logging Early
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

```yaml
wandb_project: my-project
wandb_run_id: # optional, for resuming
logging_steps: 1
```
:::

### Key Metrics for RL (GRPO)

GRPO training logs a richer set of metrics. These are the critical ones:

| Metric | Healthy Range | Red Flag |
|--------|---------------|----------|
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 -- reward function is broken or task is too hard |
| `reward_std` | > 0 on most steps | Always 0 -- no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step -- zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001--1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05--0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0--0.5 | > 2.0 suggests policy has diverged too far from reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; decrease `vllm_sync_interval` to sync weights more often |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` -- increase it |

::: {.callout-note}
## EBFT-Specific Metrics
For EBFT training, also monitor `ebft/alignment` (should trend upward, healthy 0.3--0.9), `ebft/diversity` (healthy 0.01--0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
:::

## SFT Stability

### Loss Plateau

**Symptom**: Loss stops decreasing early in training, well above expected values.

**Causes and fixes**:

- **Learning rate too low**: Increase by 2--5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- **Insufficient warmup**: Set `warmup_steps` to 5--10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- **Data quality**: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable (see the sketch after this list).
- **Weight decay too high**: The default of 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.

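
To make the label-masking check concrete, here is a minimal sketch that loads a preprocessed dataset (as written by `axolotl preprocess`) and prints which tokens are trainable. The dataset path and tokenizer name are placeholders -- substitute the ones from your run.

```python
# Minimal sketch: confirm only the target tokens carry labels (labels != -100).
# The dataset path and tokenizer name below are placeholders for your own run.
from datasets import load_from_disk
from transformers import AutoTokenizer

ds = load_from_disk("last_run_prepared/<dataset-hash>")   # output of `axolotl preprocess`
tok = AutoTokenizer.from_pretrained("<your-base-model>")

sample = ds[0]
trainable_ids = [t for t, l in zip(sample["input_ids"], sample["labels"]) if l != -100]
print(f"total tokens: {len(sample['input_ids'])}, trainable tokens: {len(trainable_ids)}")
print("trainable text:", tok.decode(trainable_ids))
```
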
### Loss Spikes

**Symptom**: Loss suddenly jumps by 2--10x, then (possibly) recovers.

**Causes and fixes**:

- **Bad data samples**: A single malformed or extremely long example can cause a spike. Set `sample_packing: false` temporarily and check whether spikes correlate with specific batches.
- **Learning rate too high**: Reduce by 2--5x, or increase warmup.
- **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus` (e.g., 2 * 8 * 4 = 64). Very large effective batch sizes amplify gradient noise.
- **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.

### Overfitting

**Symptom**: Train loss keeps decreasing but eval loss starts increasing.

**Fixes**:

- Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss` (a simple divergence check is sketched after this list).
- Reduce `num_epochs` or `max_steps`.
- Increase `weight_decay` (try 0.01--0.1).
- Use a smaller LoRA rank (`lora_r`). Typical values: 8--32.
- Increase dropout: `lora_dropout: 0.05`.

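
As a rough illustration of the divergence check described above, the sketch below flags a run as overfitting when the last few eval losses rise while train loss keeps falling. The helper and its thresholds are hypothetical -- adapt them to however you export your metrics.

```python
# Hypothetical helper: flag overfitting when eval loss rises over the last
# few evaluations while train loss keeps decreasing.
def is_overfitting(train_losses, eval_losses, window=3):
    if len(eval_losses) <= window or len(train_losses) <= window:
        return False
    eval_rising = all(b > a for a, b in zip(eval_losses[-window - 1:-1], eval_losses[-window:]))
    train_falling = train_losses[-1] < train_losses[-window - 1]
    return eval_rising and train_falling

print(is_overfitting([2.0, 1.5, 1.2, 1.0, 0.9], [1.8, 1.6, 1.65, 1.7, 1.8]))  # True
```
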
## RL/GRPO Stability

### Reward Never Increases

If `rewards/*/mean` stays at 0 for more than 20 steps:

1. **Test the reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values (a fuller sketch follows this list).

   ```bash
   cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
   ```

2. **Check dataset columns**: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.

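
Expanding on step 1, here is a hedged sketch of calling a custom reward function directly with a hand-written completion. The module name `my_rewards`, the function `accuracy_reward`, and the exact keyword arguments are placeholders -- match them to your reward function's signature and the dataset columns it expects.

```python
# Sketch: exercise a custom reward function outside of training.
# `my_rewards.accuracy_reward` and its signature are placeholders.
import my_rewards

# Conversational-style completions: one list of messages per sampled completion.
completions = [[{"role": "assistant", "content": "The answer is 42."}]]

scores = my_rewards.accuracy_reward(completions=completions, answer=["42"])
print(scores)   # expect a nonzero score for a correct completion, e.g. [1.0]
```
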
### Entropy Collapse (Mode Collapse)

**Symptom**: `entropy` drops below 0.01; all completions become nearly identical (an illustration of how entropy is computed follows the fixes below).

**Fixes**:

- Increase `temperature` in generation kwargs (try 0.8--1.0).
- Reduce the learning rate.
- Add a KL penalty term (the `beta` parameter in the GRPO config).
- Check that `num_generations` is sufficient (16+ gives better advantage estimates).

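
For intuition about what the `entropy` metric measures, the snippet below computes per-token entropy from a dummy logits tensor; a healthy policy keeps some spread over the vocabulary, while a collapsed one concentrates almost all probability on a single token. The shapes are arbitrary, and this is an illustration rather than the trainer's exact computation.

```python
# Illustration: per-token entropy of the policy distribution.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 32, 32000)            # (batch, seq_len, vocab) -- dummy values
logp = F.log_softmax(logits, dim=-1)
entropy = -(logp.exp() * logp).sum(dim=-1)    # shape (batch, seq_len)
print(entropy.mean().item())                  # near-zero values here would signal collapse
```
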
### IS Ratio Divergence

**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.

This means the policy has diverged significantly from the weights vLLM used for generation, and the importance sampling correction becomes unreliable (a simplified illustration of these two metrics follows the fixes below).

**Fixes**:

- Decrease `vllm_sync_interval` (sync weights more often).
- Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
- Use `importance_sampling_level: token` for finer-grained correction.

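
To clarify what the two sampling diagnostics compare, here is a simplified illustration: the log-probability difference and importance sampling ratio relate the current policy's per-token logprobs to the logprobs under the (possibly stale) vLLM weights that generated the completions. The numbers are made up, and this is not the trainer's exact implementation.

```python
import torch

policy_logps = torch.tensor([-1.2, -0.8, -2.5])     # current policy, per token
sampling_logps = torch.tensor([-1.1, -0.9, -0.6])   # vLLM weights at generation time

logp_diff = (policy_logps - sampling_logps).abs()
is_ratio = (policy_logps - sampling_logps).exp()

print(logp_diff.mean().item())   # ~ sampling/sampling_logp_difference/mean
print(is_ratio.min().item())     # ~ sampling/importance_sampling_ratio/min (near 0 = stale data)
```
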
### Gradient Norm Instability

**Symptom**: `grad_norm` oscillates wildly or regularly exceeds 10.0.

**Fixes**:

- Enable gradient clipping: `max_grad_norm: 1.0` (the default in most configs; see the sketch after this list for what it does).
- Reduce the learning rate.
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see the next section).

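
For reference, `max_grad_norm` maps onto PyTorch's global-norm clipping: gradients are rescaled so their combined L2 norm does not exceed the threshold. A minimal sketch of the underlying call:

```python
import torch

model = torch.nn.Linear(16, 16)
loss = model(torch.randn(4, 16)).pow(2).sum() * 1e3   # artificially large loss
loss.backward()

# clip_grad_norm_ returns the total norm *before* clipping -- the same kind of
# pre-clip value that shows up in the grad_norm metric.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {total_norm:.2f} (rescaled to <= 1.0)")
```
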
## NaN and Inf Handling

### Common Causes

| Cause | Where It Manifests | Detection |
|-------|-------------------|-----------|
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |

### FP8-Specific NaN Issues

FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero (illustrated after the list of fixes below).

**Fixes applied in axolotl**:

- The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
- A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
- Embedding padding is zero-padded for FP8 compatibility.

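
A small illustration of the failure mode (plain PyTorch, not the Triton kernel itself): with an all-zero input the computed scale is 0, and dividing by it produces NaN; the zero-guard substitutes 1.0 so padding rows stay zero.

```python
import torch

x = torch.zeros(8)                       # all-zero activation row (e.g. padding)
scale = x.abs().max() / 448.0            # -> 0.0
print(x / scale)                         # -> NaNs (0 / 0)

safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
print(x / safe_scale)                    # -> zeros, as intended
```
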
::: {.callout-important}
## After Modifying Triton Kernels
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:

```bash
rm -rf ~/.triton/cache
```
:::

### General NaN Debugging Steps

1. **Enable anomaly detection** (slow, but pinpoints the source):

   ```python
   torch.autograd.set_detect_anomaly(True)
   ```

2. **Check `grad_norm`**: If it goes to NaN, the backward pass is the problem. If loss is NaN but `grad_norm` was fine on the previous step, the forward pass is the problem (a hook-based locator is sketched after this list).
3. **Reduce to a single GPU, single batch**: Eliminate distributed-training variables.
4. **Inspect data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.

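
When anomaly detection is too slow, a lighter-weight option is to hook every module and fail fast on the first non-finite output. This is a generic sketch for any `nn.Module`-based model, not an axolotl API:

```python
import torch

def attach_nan_hooks(model: torch.nn.Module):
    """Raise on the first module whose output contains NaN or Inf."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"Non-finite values first appeared in: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# attach_nan_hooks(model)   # call once before running the failing training step
```
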
## OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:

### Step 1: Reduce Batch Size

The single highest-impact change. VRAM scales roughly linearly with batch size.

```yaml
micro_batch_size: 1 # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
```

For GRPO specifically, the logits tensor for policy logprob computation can be very large: it holds `batch_size * num_generations * seq_len * vocab_size` elements in bf16. For example, with `num_generations: 16` and `micro_batch_size: 8`, the logits tensor alone is:

```
8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (way too large)
```

Reduce `micro_batch_size` to 2--4 for GRPO.

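
The same arithmetic as a reusable helper, so you can check your own settings before launching (sizes in GiB, assuming bf16 logits):

```python
def logits_gib(micro_batch_size, num_generations, seq_len, vocab_size, bytes_per_el=2):
    """Rough size of the GRPO policy-logprob logits tensor."""
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_el / 1024**3

print(f"{logits_gib(8, 16, 2048, 151936):.0f} GiB")   # ~74 GiB -- way too large
print(f"{logits_gib(2, 16, 2048, 151936):.0f} GiB")   # ~19 GiB -- far more manageable
```
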
### Step 2: Enable Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass instead of storing them.

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false # Recommended default
```

::: {.callout-warning}
## Reentrant Checkpointing Exceptions
Some configurations require `use_reentrant: true`:

- DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
- EBFT strided mode with flex_attention
:::

### Step 3: Use Quantization

Load the base model in reduced precision:

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```

### Step 4: Reduce Sequence Length

```yaml
sequence_len: 1024 # Down from 2048 or 4096
```

For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.

### Step 5: Use Flash Attention

Reduces attention memory from O(n^2) to O(n):

```yaml
flash_attention: true
```

### Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to CPU:

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```

### Diagnosing the Specific Culprit

Use the `profiler_steps` config option to capture GPU memory snapshots:

```yaml
profiler_steps: [1, 2]
```

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.

## Common Errors

| Error Message | Likely Cause | Fix |
|---------------|-------------|-----|
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](nccl.qmd); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on single GPU | Remove `deepspeed:` section from config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | `micro_batch_size` too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (auto bf16 to fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check `eos_token` matches template; file a GitHub issue if persistent |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |

## Profiling

### PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

```yaml
profiler_steps: [1, 2, 3]
```

This captures profiler traces for the specified steps. View them in TensorBoard:

```bash
tensorboard --logdir output_dir/runs
```

Or open the `.json` trace file in `chrome://tracing`.

### CUDA Memory Snapshots

For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:

```python
import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```

Visualize the snapshot by dragging the pickle file into PyTorch's memory visualizer at <https://pytorch.org/memory_viz>.

### Quick GPU Memory Check

During training, monitor GPU utilization in a separate terminal:

```bash
watch -n 1 nvidia-smi
```

For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note that these report PyTorch's view of memory, which may differ from `nvidia-smi` (see the [FAQ](faq.qmd)).

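
If you want the same numbers the logged metrics are based on, a quick interactive check looks like this:

```python
import torch

print(f"max allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
print(f"max reserved:  {torch.cuda.max_memory_reserved() / 1024**3:.2f} GiB")
```
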
## W&B and Logging

### Enabling Logging

```yaml
wandb_project: my-project
wandb_entity: my-team # optional
wandb_run_id: run-123 # optional, for resuming
wandb_name: experiment-name # optional
logging_steps: 1 # log every step (recommended for RL)
```

### Debug Logging

For detailed axolotl-internal debug output:

```bash
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```

::: {.callout-tip}
## Always Log to a File
Pipe training output to a log file so you can inspect it after the run:

```bash
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```
:::

### What Axolotl Logs

**SFT metrics** (logged every `logging_steps`):

- `train/loss`, `eval/loss` -- training and validation loss
- `train/grad_norm` -- gradient L2 norm (before clipping)
- `train/learning_rate` -- current learning rate
- `memory/max_alloc`, `memory/max_reserved` -- peak GPU memory

**GRPO/RL metrics** (logged every step):

- `rewards/<name>/mean`, `rewards/<name>/std` -- per-reward-function statistics
- `reward`, `reward_std` -- aggregated reward across all reward functions
- `frac_reward_zero_std` -- fraction of prompt groups where all completions got the same reward
- `completions/mean_length`, `completions/min_length`, `completions/max_length` -- completion token lengths
- `completions/clipped_ratio` -- fraction of completions that hit the max length
- `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` -- lengths of naturally terminated completions
- `kl` -- KL divergence between policy and reference
- `entropy` -- policy entropy (measure of output diversity)
- `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` -- PPO clipping statistics
- `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` -- log-probability difference between policy and sampling distribution
- `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` -- IS ratio statistics for off-policy correction
- `num_tokens` -- total tokens processed

### Reading W&B Charts

For a healthy GRPO run, expect to see:

1. **`reward/mean`**: Gradual upward trend. May start near 0 and reach 0.3--0.8 depending on task difficulty. Not monotonic -- fluctuations are normal.
2. **`entropy`**: Gradual decrease from initial values (often 0.3--0.6) as the model becomes more confident. Should not collapse to near-zero.
3. **`grad_norm`**: Mostly in the 0.001--1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task's natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.