Nemo gym integration (#3516) [skip ci]

* nemo gym integration with grpo wip

* mostly working

* cleanup

* simplify

* update docs

* nemo gym support wip

* cleanup

* chore: lint

* address PR review and add more tests

* chore: lint

* post merge lora fixes for CI (#3536) [skip ci]

* post merge lora fixes for CI

* handle lora kernel auto-enable for moe without grouped_mm

* prefer not to import torch in schema validation

* address pr comments, add timeout, add tests

* roundup_power2_divisions not needed with newer pytorch versions (#3540)

* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg

* fix: robust handling of race condition on patching check (#3543) [skip ci]

* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict

* fix for ebft sync and update docs

* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Wing Lian
2026-03-25 07:38:06 -04:00
committed by GitHub
parent 2fb72798e0
commit c2bd75aff6
20 changed files with 3592 additions and 19 deletions

View File

@@ -18,6 +18,8 @@ feedback. Various methods include, but not limited to:
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
- [Energy-Based Fine-Tuning (EBFT)](#ebft)
- [NeMo Gym Integration](#nemo-gym-integration)
## RLHF using Axolotl
@@ -1037,6 +1039,302 @@ simpo_gamma: 0.5 # default in CPOTrainer
This method uses the same dataset format as [DPO](#dpo).
### EBFT
EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)
**Key advantages:**
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (code, translation, creative writing)
- Operates on model rollouts (not teacher forcing), reducing distribution shift
EBFT supports two modes:
- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed.
#### Structured Mode
```yaml
base_model: Qwen/Qwen3-4B
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth
embed_method: last_token
use_whitening: false
alignment_coef: 1.0 # Cosine similarity reward weight
diversity_coef: 1.0 # Pairwise dot product penalty
ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off)
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_lora_sync: true # LoRA adapter sync (recommended)
vllm_sync_interval: 3
use_data_producer: true
async_prefetch: true # Set false for sync mode
scale_rewards: true
loss_type: grpo
epsilon: 0.2
vllm:
gpu_memory_utilization: 0.5
max_model_len: 2048
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
```
```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
#### Strided Mode
For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.
```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided
stride: 8
context_length: 8
generate_max_len: 8
n_samples_per_prompt: 4
temperature: 0.6
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.03
advantage_estimator: rloo
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention
```
```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
::: {.callout-tip}
See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.
:::
#### EBFT Configuration Reference
| Parameter | Default | Description |
|-----------|---------|-------------|
| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) |
| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` |
| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions |
| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight |
| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight |
| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens |
| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) |
| `ebft.stride` | — | Tokens between anchor points (strided mode) |
| `ebft.context_length` | — | Context window per block (strided mode) |
| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) |
| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) |
| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) |
### NeMo Gym Integration
[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`).
#### Single-Turn (Simplest)
For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server.
```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct
rl: grpo
chat_template: tokenizer_default
trl:
use_vllm: false # Colocate mode (single GPU)
num_generations: 4
max_completion_length: 128
temperature: 0.9
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
- path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
server_name: reasoning_gym
datasets:
- path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
```
```bash
# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
"+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
"+skip_venv_if_present=true"
# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
::: {.callout-note}
`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined.
:::
#### Multi-Turn with Async GRPO (Recommended)
For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.
```yaml
base_model: Qwen/Qwen3-0.6B
rl: grpo
chat_template: tokenizer_default
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
trl:
use_vllm: true
vllm_mode: server
vllm_server_host: localhost
vllm_server_port: 8000
vllm_lora_sync: true
vllm_sync_interval: 5
use_data_producer: true
async_prefetch: true # 3x speedup
num_generations: 4
max_completion_length: 512
temperature: 0.8
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_env
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
- path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
server_name: example_single_tool_call
datasets:
- path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
vllm:
gpu_memory_utilization: 0.85
max_model_len: 2048
```
Multi-turn requires three services running:
```bash
# Terminal 1: vLLM with LoRA + tool calling
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-0.6B --max-model-len 2048 \
--gpu-memory-utilization 0.85 \
--enable-lora --max-lora-rank 64 \
--enable-auto-tool-choice --tool-call-parser hermes
# Terminal 2: NeMo Gym servers (resource + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
"+config_paths=[configs/axolotl_tool_calling.yaml]" \
"+skip_venv_if_present=true"
# Terminal 3: Training
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-important}
Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format.
:::
#### NeMo Gym Prerequisites
```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync
# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```
#### NeMo Gym Configuration Reference
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |
#### Reward Functions
| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |
### Using local dataset files
```yaml