Nemo gym integration (#3516) [skip ci]
* nemo gym integration with grpo wip
* mostly working
* cleanup
* simplify
* update docs
* nemo gym support wip
* cleanup
* chore: lint
* address PR review and add more tests
* chore: lint
* post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI
* handle lora kernel auto-enable for moe without grouped_mm
* prefer not to import torch in schema validation
* address pr comments, add timeout, add tests
* roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* fix: robust handling of race condition on patching check (#3543) [skip ci]
* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixes
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address core review feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* fix for ebft sync and update docs
* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

docs/rlhf.qmd (+298)

@@ -18,6 +18,8 @@ feedback. Various methods include, but not limited to:

- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
- [Energy-Based Fine-Tuning (EBFT)](#ebft)
- [NeMo Gym Integration](#nemo-gym-integration)

## RLHF using Axolotl
@@ -1037,6 +1039,302 @@ simpo_gamma: 0.5 # default in CPOTrainer

This method uses the same dataset format as [DPO](#dpo).

### EBFT

EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)

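
At a glance, the objective rewards rollouts whose frozen-model features align with the ground-truth completion's features, while penalizing rollouts that collapse onto each other. The sketch below is illustrative only: the function name and tensor layout are assumptions, not axolotl's implementation.

```python
# Hypothetical sketch of the EBFT per-sample reward, not axolotl's actual code.
import torch
import torch.nn.functional as F

def ebft_reward(gen_feats, gt_feats, alignment_coef=1.0, diversity_coef=1.0):
    """Reward = cosine alignment to ground-truth features, minus a pairwise
    dot-product penalty that discourages collapsed (identical) rollouts.

    gen_feats: (G, D) features of G rollouts for one prompt
    gt_feats:  (D,)   features of the ground-truth completion
    """
    gen = F.normalize(gen_feats, dim=-1)
    gt = F.normalize(gt_feats, dim=-1)
    alignment = gen @ gt                       # (G,) cosine similarity
    sims = gen @ gen.T                         # (G, G) pairwise similarities
    num = gen.shape[0]
    off_diag = sims.masked_fill(torch.eye(num, dtype=torch.bool), 0.0)
    diversity_penalty = off_diag.sum(dim=-1) / max(num - 1, 1)
    return alignment_coef * alignment - diversity_coef * diversity_penalty

# REINFORCE update: the (detached) reward weights each rollout's log-prob,
# e.g. loss = -(advantages.detach() * completion_logprobs).mean()
```
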

**Key advantages:**

- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (code, translation, creative writing)
- Operates on model rollouts (not teacher forcing), reducing distribution shift


EBFT supports two modes:

- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed (see the block-layout sketch after this list).

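
For intuition on the strided layout, the illustrative helper below (names and exact layout are assumptions, not axolotl's internals) shows how `stride`, `context_length`, and `generate_max_len` could carve a document into overlapping generation blocks that are regenerated in parallel:

```python
# Illustrative only: how strided anchors might carve a document into blocks.
def strided_blocks(num_tokens, stride=8, context_length=8, generate_max_len=8):
    """Yield (context_start, anchor, gen_end) spans: every `stride` tokens,
    condition on the previous `context_length` tokens and regenerate the
    next `generate_max_len` tokens, in parallel across anchors."""
    blocks = []
    for anchor in range(context_length, num_tokens - generate_max_len, stride):
        blocks.append((anchor - context_length, anchor, anchor + generate_max_len))
    return blocks

print(strided_blocks(40))
# [(0, 8, 16), (8, 16, 24), (16, 24, 32)]
```
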
#### Structured Mode

```yaml
base_model: Qwen/Qwen3-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Extract features at 25%, 50%, 75% depth
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0  # Cosine similarity reward weight
  diversity_coef: 1.0  # Pairwise dot product penalty
  ce_coef: 0.0  # Cross-entropy on GT tokens (0 = off)

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true  # LoRA adapter sync (recommended)
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: true  # Set false for sync mode
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
```


```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

#### Strided Mode

For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.

```yaml
base_model: meta-llama/Llama-3.2-1B

rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03
  advantage_estimator: rloo

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

flash_attention: false
flex_attention: true  # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required for flex_attention
```
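
The config above selects `advantage_estimator: rloo`. The sketch below shows the standard GRPO (group-normalized) and RLOO (leave-one-out) estimators over a group of rewards; these follow the textbook formulas and are not necessarily axolotl's exact code:

```python
# Standard advantage estimators for a group of rollout rewards.
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-4):
    """Group-normalized: subtract the group mean, divide by the group std."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def rloo_advantages(rewards: torch.Tensor):
    """Leave-one-out baseline: each sample's reward minus the mean of the
    *other* samples in its group."""
    n = rewards.numel()
    loo_mean = (rewards.sum() - rewards) / (n - 1)
    return rewards - loo_mean

r = torch.tensor([1.0, 0.0, 0.5, 0.5])
print(grpo_advantages(r))  # zero-mean, roughly unit-variance
print(rloo_advantages(r))  # e.g. 1.0 - mean([0.0, 0.5, 0.5]) = 0.6667
```
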

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-tip}
See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.
:::

#### EBFT Configuration Reference

| Parameter | Default | Description |
|-----------|---------|-------------|
| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) |
| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` |
| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions |
| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight |
| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight |
| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens (0 = off) |
| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) |
| `ebft.stride` | — | Tokens between anchor points (strided mode) |
| `ebft.context_length` | — | Context window per block (strided mode) |
| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) |
| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) |
| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) |

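
For intuition, here is a minimal sketch (assumptions, not axolotl's code) of how fractional `feature_layers` could resolve to concrete layer indices, and how the `last_token` and `mean_pooling` options differ:

```python
# Hypothetical helpers; names and details are assumptions, not axolotl's API.
import torch

def resolve_feature_layers(fractions, num_hidden_layers):
    """Map fractional depths like 0.25 to integer layer indices."""
    return [round(f * num_hidden_layers) for f in fractions]

def pool_features(hidden, mask, method="last_token"):
    """hidden: (B, T, D) hidden states; mask: (B, T) attention mask."""
    if method == "last_token":
        last = mask.sum(dim=1) - 1  # index of each sequence's final token
        return hidden[torch.arange(hidden.size(0)), last]
    if method == "mean_pooling":
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
    raise ValueError(f"unknown embed_method: {method}")

print(resolve_feature_layers([0.25, 0.5, 0.75], num_hidden_layers=32))
# [8, 16, 24]
```
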
### NeMo Gym Integration

[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`).

#### Single-Turn (Simplest)

For environments that only need answer verification (math, coding challenges), no agent server is needed: the reward function calls `/verify` directly on the resource server.

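
For intuition, a reward function in this style might look like the sketch below. This is illustrative only: the endpoint URL and payload field names are assumptions, and the real implementation is `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify`.

```python
# Illustrative sketch of a /verify-style reward function; details may differ
# from the actual integration.
import requests

def verify_reward(completions, verify_url="http://localhost:11000/verify",
                  timeout=30, **kwargs):
    """Score each completion by asking the resource server to verify it."""
    rewards = []
    for completion in completions:
        resp = requests.post(
            verify_url,
            json={"response": completion},  # field name is an assumption
            timeout=timeout,
        )
        resp.raise_for_status()
        rewards.append(float(resp.json().get("reward", 0.0)))
    return rewards
```
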
```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: false  # Colocate mode (single GPU)
  num_generations: 4
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role
```
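
Before wiring `field_messages`, it can help to peek at one JSONL row and confirm the nested key layout the config above implies (plain Python, no axolotl imports; the nested structure is inferred from `responses_create_params.input`):

```python
import json
import os

# Inspect one row of the dataset to confirm the message field layout.
path = os.path.expanduser(
    "~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl"
)
with open(path) as f:
    row = json.loads(f.readline())

print(list(row))                                   # top-level keys
print(row["responses_create_params"]["input"][0])  # first chat message
```
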

```bash
# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
    "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
    "+skip_venv_if_present=true"

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-note}
`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined.
:::

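
The failure mode is easy to see with plain string concatenation (a sketch of the assumed join behavior, not the integration's exact code):

```python
# Why absolute paths break, assuming the integration prefixes nemo_gym_dir.
nemo_gym_dir = "/home/user/Gym"
rel = "resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl"
bad = "/home/user/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl"

print(f"{nemo_gym_dir}/{rel}")  # correct: a single Gym prefix
print(f"{nemo_gym_dir}/{bad}")  # double-joined: /home/user/Gym//home/user/Gym/...
```
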
#### Multi-Turn with Async GRPO (Recommended)

For environments with tool use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.

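
Conceptually, the agent loop looks like the sketch below; the helper names (`generate`, `parse_tool_calls`, `execute_tool`) are hypothetical stand-ins for what the agent server does:

```python
# Hedged sketch of a multi-turn agent loop; not NeMo Gym's actual code.
def run_episode(messages, generate, parse_tool_calls, execute_tool,
                max_turns=8):
    """Generate -> parse tool calls -> execute -> feed results back,
    until the model stops calling tools or the turn budget runs out."""
    for _ in range(max_turns):
        reply = generate(messages)            # one model completion
        messages.append({"role": "assistant", "content": reply})
        calls = parse_tool_calls(reply)
        if not calls:                         # no tool call: episode is done
            break
        for call in calls:
            result = execute_tool(call)
            messages.append({"role": "tool", "content": result})
    return messages
```
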
```yaml
base_model: Qwen/Qwen3-0.6B

rl: grpo
chat_template: tokenizer_default

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]

trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 5
  use_data_producer: true
  async_prefetch: true  # 3x speedup
  num_generations: 4
  max_completion_length: 512
  temperature: 0.8
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_env

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
  - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    server_name: example_single_tool_call

datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.85
  max_model_len: 2048
```

Multi-turn requires three services running:

```bash
# Terminal 1: vLLM with LoRA + tool calling
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \
    python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-0.6B --max-model-len 2048 \
    --gpu-memory-utilization 0.85 \
    --enable-lora --max-lora-rank 64 \
    --enable-auto-tool-choice --tool-call-parser hermes

# Terminal 2: NeMo Gym servers (resource + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
    "+config_paths=[configs/axolotl_tool_calling.yaml]" \
    "+skip_venv_if_present=true"

# Terminal 3: Training
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-important}
Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for the agent config format.
:::

#### NeMo Gym Prerequisites

```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```

#### NeMo Gym Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to the NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |

#### Reward Functions

| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |

### Using local dataset files