Nemo gym integration (#3516) [skip ci]

* nemo gym integration with grpo wip

* mostly working

* cleanup

* simplify

* update docs

* nemo gym support wip

* cleanup

* chore: lint

* address PR review and add more tests

* chore: lint

* post merge lora fixes for CI (#3536) [skip ci]

* post merge lora fixes for CI

* handle lora kernel auto-enable for moe without grouped_mm

* prefer not to import torch in schema validation

* address pr comments, add timeout, add tests

* roundup_power2_divisions not needed with newer pytorch versions (#3540)

* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg

* fix: robust handling of race condition on patching check (#3543) [skip ci]

* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixes

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address code review feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict

* fix for ebft sync and update docs

* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wing Lian
2026-03-25 07:38:06 -04:00
committed by GitHub
parent 2fb72798e0
commit c2bd75aff6
20 changed files with 3592 additions and 19 deletions

View File

@@ -176,6 +176,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if not async_grpo:
# Filter out async/fast-async-only fields not in standard GRPOConfig.
# These are defined in FastAsyncGRPOConfig and only used by
# AxolotlAsyncGRPOConfig. Standard GRPOConfig rejects them.
import dataclasses
from trl import GRPOConfig as _BaseGRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import (
FastAsyncGRPOConfig,
)
async_only_fields = {
f.name for f in dataclasses.fields(FastAsyncGRPOConfig)
} - {f.name for f in dataclasses.fields(_BaseGRPOConfig)}
blocklist_args_kwargs.extend(list(async_only_fields))
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"

View File

@@ -34,7 +34,16 @@ class EBFTStrategy:
return AxolotlStridedEBFTTrainer
# Structured mode: async or sync
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
# use_data_producer also triggers async trainer (needed for LoRA sync
# without async_prefetch, since sync trainer lacks LoRA sync support)
use_async = (
cfg
and cfg.trl
and (
getattr(cfg.trl, "async_prefetch", False)
or getattr(cfg.trl, "use_data_producer", False)
)
)
if use_async:
from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer
@@ -50,7 +59,14 @@ class EBFTStrategy:
return AxolotlStridedEBFTConfig
# Structured mode: async or sync config
use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
use_async = (
cfg
and cfg.trl
and (
getattr(cfg.trl, "async_prefetch", False)
or getattr(cfg.trl, "use_data_producer", False)
)
)
if use_async:
return AxolotlAsyncEBFTConfig
return AxolotlEBFTConfig

View File

@@ -1012,27 +1012,44 @@ class AsyncGRPOTrainer(GRPOTrainer):
import requests
vllm_client = self.vllm_generation.vllm_client
url = f"{vllm_client.base_url}/set_lora_adapter/"
base_url = vllm_client.base_url
base_model = getattr(self.args, "model_name_or_path", "axolotl-lora")
sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
# Try standard vLLM /v1/load_lora_adapter first, fall back to custom endpoint
response = requests.post(
url,
f"{base_url}/v1/load_lora_adapter",
json={
"lora_name": "active_lora",
"lora_int_id": self._lora_sync_version,
"lora_name": base_model,
"lora_path": adapter_path,
"load_inplace": True,
},
timeout=sync_timeout,
)
if response.status_code != 200:
logger.warning(
"Failed to set LoRA adapter: %s %s",
response.status_code,
response.text,
# Fallback: try custom /set_lora_adapter/ endpoint
response = requests.post(
f"{base_url}/set_lora_adapter/",
json={
"lora_name": "active_lora",
"lora_int_id": self._lora_sync_version,
"lora_path": adapter_path,
},
timeout=30,
)
return
if response.status_code != 200:
logger.warning(
"Failed to set LoRA adapter: %s %s",
response.status_code,
response.text,
)
return
# Reset prefix cache after adapter update
vllm_client.reset_prefix_cache()
try:
vllm_client.reset_prefix_cache()
except Exception as exc:
logger.warning("Failed to reset prefix cache: %s", exc)
# Clean up old adapter versions (keep only current)
if self._lora_sync_version > 1:
@@ -2486,6 +2503,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
logits, completion_ids, self.temperature
)
all_logps.append(logps)
# Liger fused path doesn't compute entropy — append zeros
if compute_entropy:
all_entropies.append(torch.zeros_like(logps))
else:
logits = logits[:, :-1, :]
logits = logits[:, -logits_to_keep:, :]

View File

@@ -0,0 +1,412 @@
# NeMo Gym Integration for Axolotl
Train LLMs with reinforcement learning using [NVIDIA NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) environments as reward sources. NeMo Gym provides 50+ verified RL environments spanning math, coding, tool-use, reasoning, and safety — each with deterministic reward signals.
## Validated Training Paths
| Path | Speed | Multi-turn | Architecture |
|------|-------|------------|--------------|
| **Async GRPO + Data Producer** | Fastest (3x) | Yes | `NemoGymDataProducer` replaces vLLM generation |
| Standard GRPO + Data Producer | Baseline | Yes | Same producer, no async prefetch |
| Standard GRPO + /verify | Simplest | No | Reward function calls /verify directly |
| FSDP2 + /verify (2 GPU) | Distributed | No | `fsdp_version: 2` |
Multi-turn uses `nemo_gym_multi_turn: true`, which auto-enables the async trainer's
data producer protocol. The plugin's `NemoGymDataProducer` calls NeMo Gym agent `/run`
endpoints and returns a `RolloutDataset` with proper IS correction, env_mask, and rewards.
All paths tested end-to-end with Qwen3-0.6B + LoRA, logged to wandb project `nemo-gym-rl`.
## Quick Start
### Prerequisites
- [uv](https://github.com/astral-sh/uv) package manager (for NeMo Gym's venv)
- Two GPUs recommended (one for vLLM server, one for training)
### 1. Set Up NeMo Gym
```bash
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync
# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
# Pre-build resource server venvs
for dir in resources_servers/reasoning_gym resources_servers/example_single_tool_call responses_api_models/vllm_model responses_api_agents/simple_agent; do
uv venv --seed --allow-existing --python 3.12 $dir/.venv
CFLAGS="" uv pip install --python $dir/.venv/bin/python pycosat --no-build-isolation 2>/dev/null
uv pip install --python $dir/.venv/bin/python -e . "ray[default]==2.52.1"
done
# Install extra deps for reasoning_gym
uv pip install --python resources_servers/reasoning_gym/.venv/bin/python \
reasoning-gym matplotlib pillow cycler contourpy kiwisolver
```
### 2. Multi-Turn with Async GRPO (Recommended — Fastest Path)
This is the fully validated, highest-performance path. NeMo Gym's agent server handles
multi-turn tool execution while axolotl's async GRPO prefetches data in background threads.
**Step 1: Create the NeMo Gym agent config**
Create `~/Gym/configs/axolotl_tool_calling.yaml`:
```yaml
# Resource server (tools + verify)
example_single_tool_call:
resources_servers:
example_single_tool_call:
entrypoint: app.py
domain: agent
verified: false
# Model server proxy (forwards to your vLLM)
policy_model:
responses_api_models:
vllm_model:
entrypoint: app.py
base_url: http://localhost:8000/v1
api_key: dummy_key
model: Qwen/Qwen3-0.6B # Must match your training model
return_token_id_information: true
uses_reasoning_parser: false
# Agent server (orchestrates multi-turn via /run)
example_single_tool_call_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: example_single_tool_call
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: weather
type: example
jsonl_fpath: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
```
**Step 2: Start three services**
```bash
# Terminal 1: vLLM OpenAI server on GPU 0
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-0.6B --max-model-len 2048 --gpu-memory-utilization 0.85
# Terminal 2: NeMo Gym (resource server + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
"+config_paths=[configs/axolotl_tool_calling.yaml]" "+skip_venv_if_present=true"
# Terminal 3: Training on GPU 1
cd experiments && CUDA_VISIBLE_DEVICES=1 CUDA_HOME=$HOME/env-claude-cu130/cuda_shim \
axolotl train nemo_gym_async_agent.yaml
```
**Step 3: Training config** (`nemo_gym_async_agent.yaml`):
```yaml
base_model: Qwen/Qwen3-0.6B
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
sequence_len: 2048
rl: grpo
chat_template: tokenizer_default
trl:
use_vllm: true
vllm_mode: server
vllm_server_host: localhost
vllm_server_port: 8000
vllm_lora_sync: true
vllm_sync_interval: 5
# Async GRPO — 3x faster than standard
use_data_producer: true
async_prefetch: true
num_generations: 4
max_completion_length: 512
temperature: 0.8
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_env
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
- path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
server_name: example_single_tool_call
datasets:
- path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
vllm:
gpu_memory_utilization: 0.85
max_model_len: 2048
tensor_parallel_size: 1
learning_rate: 5e-6
micro_batch_size: 1
gradient_accumulation_steps: 4
max_steps: 30
gradient_checkpointing: true
bf16: true
output_dir: ./outputs/nemo_gym_async
use_wandb: true
wandb_project: nemo-gym-rl
```
### 3. Single-Turn Training (Simplest — No Agent Server Needed)
For environments that only need single-turn verify (math, coding challenges), you don't need
an agent server. The plugin's reward function calls `/verify` directly.
```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct
rl: grpo
chat_template: tokenizer_default
trl:
use_vllm: true
vllm_mode: colocate
vllm_enable_sleep_mode: false
num_generations: 8
max_completion_length: 128
temperature: 0.9
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
- path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
server_name: reasoning_gym
datasets:
- path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
vllm:
gpu_memory_utilization: 0.3
max_model_len: 512
tensor_parallel_size: 1
learning_rate: 1e-5
micro_batch_size: 4
gradient_accumulation_steps: 2
max_steps: 50
output_dir: ./outputs/nemo_gym_arithmetic
```
This path only needs `ng_run` with resource servers (no agent config):
```bash
cd ~/Gym && ng_run "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" "+skip_venv_if_present=true"
```
## How It Works
### Single-Turn
```text
axolotl train → GRPO Trainer generates completions
→ NeMo Gym plugin reward_fn calls POST /verify on resource server
→ reward flows back to GRPO for advantage computation
```
### Multi-Turn (Agent /run)
```text
┌─────────────┐ ┌──────────────┐ ┌──────────────────┐
│ axolotl │ │ NeMo Gym │────▶│ vLLM OpenAI │
│ train │────▶│ Agent /run │◀────│ Server (GPU 0) │
│ (GPU 1) │ │ │ │ /v1/completions │
└─────────────┘ └──────┬───────┘ └──────────────────┘
┌──────────────┐
│ Resource │
│ Server │
│ (tools + │
│ verify) │
└─────────────┘
```
The agent server orchestrates the entire multi-turn loop (response shape sketched below):
1. Calls our vLLM server for model generation
2. Parses tool calls from model output
3. Executes tools against resource servers
4. Feeds tool results back to the model
5. Repeats until done, then calls /verify for reward
6. Returns token IDs + logprobs + reward to our rollout_func
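A successful `/run` response has roughly this shape (abridged; field names match the parser in the plugin's `multi_turn.py`):
```json
{
  "response": {
    "output": [
      {
        "prompt_token_ids": [151644, 872, 198],
        "generation_token_ids": [40, 1079, 13],
        "generation_log_probs": [-0.12, -0.48, -0.03]
      }
    ]
  },
  "reward": 1.0
}
```
There is one `output` entry per turn; tool-result tokens show up as growth in the next turn's `prompt_token_ids` and are masked out of the loss via `env_mask`.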
### Data Producer Architecture (Multi-Turn)
When `nemo_gym_multi_turn: true`, the plugin automatically forces `use_data_producer: true`,
which selects the `AxolotlAsyncGRPOTrainer`. The plugin then swaps the trainer's data
producer with `NemoGymDataProducer`, which:
1. Gets a prompt batch from the dataset iterator
2. Expands by `num_generations` (one agent call per rollout)
3. Calls NeMo Gym agents via async HTTP (`aiohttp` with `asyncio.gather`)
4. Parses responses into padded tensors (`RolloutDataset`)
5. Returns with `_pending_policy_logps=True` for deferred scoring
The main thread then runs `_compute_deferred_scores()`, which:
- Computes **policy logprobs** on the training model (GPU forward pass)
- Computes **IS correction** using the agent's sampling logprobs vs. the training model's logprobs (see the sketch below)
- Computes advantages with group-level normalization
- All downstream features work: replay buffer, re-roll, streaming, zero-adv skip
With `async_prefetch: true`, the data producer runs in a background thread — giving ~3x
speedup as generation and training overlap. With `async_prefetch: false`, it runs
synchronously on the main thread (still uses the data producer protocol).
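The IS correction itself is the standard per-token importance ratio computed in log space; a minimal sketch (the clamp bounds are illustrative, the commit only notes that the minimum ratio is clamped):
```python
import torch

def is_ratio(policy_logps: torch.Tensor, sampling_logps: torch.Tensor,
             min_ratio: float = 0.1, max_ratio: float = 10.0) -> torch.Tensor:
    """Per-token weight pi_policy / pi_sampling, clamped for stability.
    Bounds are illustrative; the trainer's actual values may differ."""
    return torch.exp(policy_logps - sampling_logps).clamp(min_ratio, max_ratio)
```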
### Weight Sync (LoRA Mode)
With `vllm_lora_sync: true`, the plugin (or async trainer) replaces NCCL-based weight
sync with filesystem + HTTP:
1. `accelerator.get_state_dict()` gathers LoRA weights from all ranks
2. Rank 0 saves adapter to `/tmp/lora_sync_*/vN/`
3. Rank 0 POSTs to `/v1/load_lora_adapter` on the vLLM server (falling back to the custom `/set_lora_adapter/` endpoint; see the sketch after this list)
4. vLLM loads adapter natively via Punica kernels
5. Only ~40MB transferred (vs multiple GBs for full model weights)
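A minimal sketch of the rank-0 HTTP step (step 3); the payload mirrors the request the async trainer sends, while the wrapper function and defaults are illustrative:
```python
# Sketch of the runtime LoRA load request (step 3 above). The payload
# mirrors the trainer's call to vLLM's /v1/load_lora_adapter; the server
# must be started with VLLM_ALLOW_RUNTIME_LORA_UPDATING=1.
import requests

def push_lora_adapter(base_url: str, adapter_path: str,
                      lora_name: str = "axolotl-lora", timeout: int = 300) -> bool:
    resp = requests.post(
        f"{base_url}/v1/load_lora_adapter",
        json={
            "lora_name": lora_name,     # name vLLM serves the adapter under
            "lora_path": adapter_path,  # e.g. the /tmp/lora_sync_*/vN/ dir
            "load_inplace": True,       # replace the previous adapter version
        },
        timeout=timeout,
    )
    return resp.status_code == 200
```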
### Multi-Environment Support
Datasets support per-row environment routing via `agent_ref`:
```jsonl
{"agent_ref": {"name": "reasoning_gym"}, "responses_create_params": {...}}
{"agent_ref": {"name": "instruction_following"}, "responses_create_params": {...}}
```
Or use the simpler per-dataset routing:
```yaml
nemo_gym_datasets:
- path: reasoning_data.jsonl
server_name: reasoning_gym
- path: tool_data.jsonl
server_name: example_single_tool_call
```
## Configuration Reference
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | `null` | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_clone` | bool | `true` | Auto-clone NeMo Gym repo if missing |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_config_paths` | list[str] | — | Server config YAMLs (relative to gym_dir) |
| `nemo_gym_datasets` | list[dict] | required | Dataset configs with `path` and optional `server_name` |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_server_timeout` | int | `360` | Server startup timeout (seconds) |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent /run |
### Dataset JSONL Format
Each line must have `responses_create_params` with `input` messages:
```json
{
"responses_create_params": {
"input": [{"role": "user", "content": "What's the weather in SF?"}],
"tools": [{"name": "get_weather", "type": "function", "strict": true, "parameters": {...}}]
}
}
```
For multi-turn agent routing, include `agent_ref`:
```json
{"agent_ref": {"name": "my_agent"}, "responses_create_params": {...}}
```
Note: Tool definitions MUST include `"strict": true` and `"additionalProperties": false` for NeMo Gym agent compatibility.
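For example, a complete weather-lookup tool definition satisfying both constraints (the schema contents are hypothetical):
```json
{
  "name": "get_weather",
  "type": "function",
  "strict": true,
  "description": "Get the current weather for a city.",
  "parameters": {
    "type": "object",
    "properties": {
      "city": {"type": "string"}
    },
    "required": ["city"],
    "additionalProperties": false
  }
}
```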
### Reward Functions
The plugin provides two built-in reward functions — no user code needed:
```yaml
trl:
reward_funcs:
# Multi-turn (nemo_gym_multi_turn: true):
# Passthrough — agent /run already computed the reward
- axolotl.integrations.nemo_gym.rewards.reward_env
# Single-turn (nemo_gym_multi_turn: false):
# Calls /verify endpoints on NeMo Gym resource servers
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
```
Both are also importable from Python:
```python
from axolotl.integrations.nemo_gym import reward_env, reward_nemo_gym_verify
```
## Known Issues / Troubleshooting
### NeMo Gym Server Setup
- **pycosat build failure**: `CFLAGS="" uv pip install pycosat --no-build-isolation`
- **Ray version mismatch**: Pin `ray[default]==2.52.1` in all server venvs
- **Pre-build venvs**: `ng_run` creates per-server venvs via Ray. Pre-build them and use `+skip_venv_if_present=true`
- **Tool `strict` field required**: The agent server validates that tool definitions include `strict: true`
### vLLM / Weight Sync
- **Start vLLM with LoRA + tool calling + runtime loading**:
```bash
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 \
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-4B-Instruct-2507 \
--max-model-len 4096 \
--gpu-memory-utilization 0.7 \
--enable-lora --max-lora-rank 64 \
--enable-auto-tool-choice --tool-call-parser hermes
```
- **`VLLM_ALLOW_RUNTIME_LORA_UPDATING=1`**: Required for `vllm_lora_sync: true`. Without it, vLLM won't expose the `/v1/load_lora_adapter` endpoint and weight sync will fail silently. The plugin warns if this endpoint is missing.
- **`--enable-lora`**: Enables LoRA adapter support in vLLM
- **`--enable-auto-tool-choice --tool-call-parser hermes`**: Required for Qwen3 tool calling
- **`max_model_len` must be > `max_completion_length`**: Leave room for prompt tokens (~200). If equal, the NeMo Gym model proxy gets a 400 error and returns empty completions.
- **`CUDA_HOME` required**: DeepSpeed import needs it for the nvcc shim
- **NCCL weight sync broken with vLLM 0.17**: Use `vllm_lora_sync: true` (filesystem + HTTP via `/v1/load_lora_adapter`)
### Multi-Turn
- **Agent server required**: Multi-turn delegates to NeMo Gym's agent server `/run` endpoint. Without an agent, the plugin falls back to single-turn `/verify`
- **Model server proxy**: NeMo Gym needs a `responses_api_models` server that proxies to your vLLM. See the agent config example above
### FSDP2
- Validated on 2 GPUs with single-turn + LoRA
- Async field filtering: The builder automatically filters async-only config fields when using the standard GRPO trainer
## Comparison with Other Integrations
| Feature | Axolotl + NeMo Gym | Unsloth + NeMo Gym | NeMo RL (native) |
|---------|-------------------|-------------------|-------------------|
| Server management | Automatic | Manual (notebook) | Built-in |
| Multi-environment | Per-row routing | Manual code | YAML config |
| Multi-turn / tool use | Agent /run delegation | No | Agent /run (Ray) |
| Async GRPO (3x speedup) | Yes | No | Yes |
| LoRA sync | Filesystem + HTTP | N/A | NCCL |
| Multi-GPU (FSDP2) | Yes | No | Yes (Ray) |
| Config-driven | Yes | No (code) | Yes |

View File

@@ -0,0 +1,25 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
Plugin for NVIDIA NeMo Gym integration with Axolotl.
NeMo Gym provides RL training environments for LLMs with verification-based
reward signals. This plugin manages the NeMo Gym server lifecycle, loads
datasets in the NeMo Gym JSONL format, and creates reward functions that
call the NeMo Gym /verify endpoints.
"""
from .args import NemoGymArgs
from .plugin import NemoGymPlugin
from .rewards import reward_env, reward_nemo_gym_verify
__all__ = [
"NemoGymArgs",
"NemoGymPlugin",
"reward_env",
"reward_nemo_gym_verify",
]

View File

@@ -0,0 +1,146 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
Input arguments for the NeMo Gym integration plugin.
"""
from pydantic import BaseModel, Field, model_validator
class NemoGymArgs(BaseModel):
"""Configuration args for the NeMo Gym integration."""
nemo_gym_enabled: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable NeMo Gym integration for environment-based RL rewards."
},
)
nemo_gym_dir: str | None = Field(
default=None,
json_schema_extra={
"description": (
"Path to the NeMo Gym repository clone. "
"If not set and nemo_gym_auto_clone is True, clones to ~/Gym."
)
},
)
nemo_gym_auto_clone: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Automatically clone the NeMo Gym repository if not present. "
"Defaults to True when nemo_gym_enabled is set."
)
},
)
nemo_gym_config_paths: list[str] | None = Field(
default=None,
json_schema_extra={
"description": (
"List of NeMo Gym resource server config YAML paths, relative to nemo_gym_dir. "
"Example: ['resources_servers/reasoning_gym/configs/resources_only.yaml']"
)
},
)
nemo_gym_head_port: int | None = Field(
default=11000,
json_schema_extra={
"description": "Port for the NeMo Gym head server. Defaults to 11000."
},
)
nemo_gym_server_timeout: int | None = Field(
default=360,
json_schema_extra={
"description": "Timeout in seconds waiting for NeMo Gym servers to start. Defaults to 360."
},
)
nemo_gym_verify_timeout: int | None = Field(
default=30,
json_schema_extra={
"description": "Timeout in seconds for individual /verify requests. Defaults to 30."
},
)
nemo_gym_run_timeout: int | None = Field(
default=300,
json_schema_extra={
"description": (
"Timeout in seconds for each agent /run request (one multi-turn rollout). "
"Prevents stuck generations (e.g. model looping on <think> tags) from "
"blocking training indefinitely. Defaults to 300 (5 minutes)."
)
},
)
nemo_gym_datasets: list[dict] | None = Field(
default=None,
json_schema_extra={
"description": (
"List of NeMo Gym dataset configs. Each entry has 'path' (JSONL file path "
"relative to nemo_gym_dir) and optionally 'server_name' (default resource server). "
"If the JSONL rows have agent_ref.name, that takes precedence per row, "
"enabling multi-environment training from a single dataset file. "
"Optional 'max_samples' to limit per dataset."
)
},
)
nemo_gym_auto_start: bool | None = Field(
default=True,
json_schema_extra={
"description": (
"Automatically start NeMo Gym resource servers. Defaults to True. "
"Set to False if servers are already running externally."
)
},
)
nemo_gym_model_name: str | None = Field(
default=None,
json_schema_extra={
"description": (
"Model name to report in verify requests. "
"Defaults to the base_model from the main config."
)
},
)
nemo_gym_multi_turn: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enable multi-turn rollouts via NeMo Gym. When True, uses TRL's "
"rollout_func to run multi-step interactions with tool execution. "
"Requires use_vllm=True in TRL config. The model generates responses, "
"tool calls are executed against resource servers, and results are "
"fed back for the next turn. Final reward comes from /verify."
)
},
)
nemo_gym_max_turns: int | None = Field(
default=None,
json_schema_extra={
"description": (
"Maximum number of turns per multi-turn rollout. Defaults to 10. "
"Each turn consists of a model generation + optional tool execution."
)
},
)
@model_validator(mode="before")
@classmethod
def check_nemo_gym_config(cls, data):
if data.get("nemo_gym_enabled"):
if not data.get("nemo_gym_config_paths") and data.get(
"nemo_gym_auto_start", True
):
raise ValueError(
"nemo_gym_config_paths is required when nemo_gym_enabled=True "
"and nemo_gym_auto_start is not False."
)
if not data.get("nemo_gym_datasets"):
raise ValueError(
"nemo_gym_datasets is required when nemo_gym_enabled=True. "
"Provide at least one dataset with 'path' and 'server_name'."
)
return data

View File

@@ -0,0 +1,226 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
NeMo Gym Data Producer for async GRPO training.
Replaces GRPODataProducer to generate rollouts via NeMo Gym agent /run endpoints
instead of vLLM. The agent handles generation, tool execution, and reward computation.
Returns RolloutDataset in the same format as the standard producer, so all downstream
components (deferred scoring, IS correction, streaming, replay, re-roll) work unchanged.
"""
from __future__ import annotations
import asyncio
from typing import Any
import torch
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.async_trainer import GRPODataProducer, RolloutDataset
from axolotl.utils.logging import get_logger
from .multi_turn import _call_agents, _parse_agent_response
LOG = get_logger(__name__)
class NemoGymDataProducer(GRPODataProducer):
"""Produces GRPO rollouts by calling NeMo Gym agent /run endpoints.
Drop-in replacement for GRPODataProducer. Instead of calling vLLM for generation,
sends prompts to NeMo Gym agents which handle generation + tool execution + reward.
Returns the same RolloutDataset format so deferred scoring, IS correction,
replay buffer, and re-roll all work unchanged.
"""
def __init__(
self,
*args,
agent_servers: dict[str, str],
dataset_lookup: dict,
request_timeout: float = 10800,
**kwargs,
):
super().__init__(*args, **kwargs)
self._agent_servers = agent_servers
self._dataset_lookup = dataset_lookup
self._request_timeout = request_timeout
def produce(
self,
model: Any,
global_step: int,
*,
skip_policy_logps: bool = False,
processing_class: Any = None,
accelerator: Any = None,
args: Any = None,
_rank0_only: bool = False,
**kwargs,
) -> RolloutDataset | None:
"""Generate rollouts via NeMo Gym agents.
Calls agent /run endpoints, parses responses into padded tensors,
and returns a RolloutDataset for deferred scoring on the main thread.
"""
trainer = self._trainer
is_main = trainer.accelerator.is_main_process
device = trainer.accelerator.device
if _rank0_only and not is_main:
return None
# Get prompt batch from iterator
try:
inputs = next(self._prompt_iter)
except StopIteration:
self._prompt_iter = iter(self._prompt_dl)
inputs = next(self._prompt_iter)
# Extract dataset items for agent calls
dataset_items = []
for inp in inputs:
prompt_text = ""
prompt = inp.get("prompt", [])
if isinstance(prompt, list) and prompt:
prompt_text = (
prompt[-1].get("content", "")
if isinstance(prompt[-1], dict)
else str(prompt[-1])
)
elif isinstance(prompt, str):
prompt_text = prompt
# Find the full dataset item, preserving agent_ref for routing
full_item = self._dataset_lookup.get(prompt_text, {})
item = full_item.get("verify_extra", {})
if not item:
item = {
"responses_create_params": {
"input": [{"role": "user", "content": prompt_text}]
}
}
# Preserve agent_ref from the dataset row for _call_agents routing
if "agent_ref" in full_item and "agent_ref" not in item:
item["agent_ref"] = full_item["agent_ref"]
dataset_items.append(item)
# Expand by num_generations (agent produces one rollout per call)
expanded_items = []
for item in dataset_items:
for _ in range(self._num_generations):
expanded_items.append(item)
# Call NeMo Gym agents
loop = asyncio.new_event_loop()
try:
responses = loop.run_until_complete(
_call_agents(
dataset_items=expanded_items,
agent_servers=self._agent_servers,
timeout=self._request_timeout,
max_completion_length=trainer.max_completion_length,
temperature=trainer.temperature,
top_p=getattr(trainer, "top_p", None) or 0.999,
)
)
finally:
loop.close()
# Parse responses
eos_token_id = trainer.processing_class.eos_token_id
prompt_ids_list = []
completion_ids_list = []
env_mask_list = []
logprobs_list = []
rewards_list = []
for resp in responses:
parsed = _parse_agent_response(resp, eos_token_id)
prompt_ids_list.append(parsed["prompt_ids"])
completion_ids_list.append(parsed["completion_ids"])
env_mask_list.append(parsed["env_mask"])
logprobs_list.append(parsed["logprobs"])
rewards_list.append(parsed["reward"])
# Pad to tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(
prompt_ids, padding_value=trainer.pad_token_id, padding_side="left"
)
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
completion_ids = [
torch.tensor(ids, device=device) for ids in completion_ids_list
]
completion_mask = [
torch.ones_like(ids, dtype=torch.long) for ids in completion_ids
]
completion_ids = pad(
completion_ids, padding_value=trainer.pad_token_id, padding_side="right"
)
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
# Sampling logprobs from agent (used for IS correction)
sampling_logps = [
torch.tensor(lp, dtype=torch.float32, device=device) for lp in logprobs_list
]
sampling_per_token_logps = pad(
sampling_logps, padding_value=0.0, padding_side="right"
)
# env_mask as tool_mask (1=model tokens, 0=tool tokens)
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
# Inject rewards into inputs so _compute_deferred_scores can use them
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
# Our passthrough reward_fn reads "env_reward" from kwargs.
for i, inp in enumerate(inputs):
# Each input gets rewards for its num_generations rollouts
start = i * self._num_generations
end = start + self._num_generations
inp["env_reward"] = rewards_list[start:end]
# Expand inputs to match expanded rollouts (num_generations copies)
expanded_inputs = []
for inp in inputs:
for g in range(self._num_generations):
expanded_inp = dict(inp)
expanded_inp["env_reward"] = inp["env_reward"][g]
expanded_inputs.append(expanded_inp)
# Decode completions for reward functions
completions = trainer.processing_class.batch_decode(
completion_ids, skip_special_tokens=True
)
# Build total token count
num_items_in_batch = completion_mask.sum()
# Build output dict (same shape as _generate_only)
output = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"num_items_in_batch": num_items_in_batch,
"advantages": torch.zeros(completion_ids.size(0), device=device),
"sampling_per_token_logps": sampling_per_token_logps,
"tool_mask": tool_mask,
# Deferred scoring markers
"_pending_policy_logps": True,
"_deferred_inputs": expanded_inputs,
"_deferred_prompts": [inp.get("prompt", "") for inp in expanded_inputs],
"_deferred_completions": completions,
"_deferred_completion_ids_list": completion_ids_list,
"_rank0_only": _rank0_only,
}
return RolloutDataset(output)

View File

@@ -0,0 +1,135 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
Dataset loading for NeMo Gym JSONL files.
Converts NeMo Gym JSONL format into HuggingFace Datasets compatible
with TRL's GRPOTrainer. Supports multi-environment routing via:
1. Per-dataset server_name (all rows in a file go to one server)
2. Per-row agent_ref.name (each row specifies its own server)
"""
from __future__ import annotations
import json
import os
import random
from datasets import Dataset
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load_nemo_gym_datasets(
gym_dir: str,
dataset_configs: list[dict],
) -> Dataset:
"""Load and merge NeMo Gym JSONL datasets with multi-environment support.
Each dataset config should have:
- path: JSONL file path (absolute, or relative to gym_dir)
- server_name: Default NeMo Gym server for this dataset.
Can be overridden per-row if the JSONL has an "agent_ref" field.
- max_samples (optional): Max number of samples to use from this dataset
Per-row routing: If a JSONL row has an "agent_ref": {"name": "..."} field,
that takes precedence over the dataset-level server_name. This allows mixing
environments within a single dataset file (matching TRL's pattern).
The output dataset has columns:
- prompt: list[dict] chat format
- resources_server_ref: dict with {"name": server_name}
- verify_extra: dict with original JSONL data for verify requests
Args:
gym_dir: Path to the NeMo Gym directory.
dataset_configs: List of dataset configuration dicts.
Returns:
A HuggingFace Dataset ready for GRPOTrainer.
"""
all_examples = []
for ds_cfg in dataset_configs:
path = ds_cfg["path"]
default_server = ds_cfg.get("server_name", "")
max_samples = ds_cfg.get("max_samples")
# Resolve path
if not os.path.isabs(path):
path = os.path.join(gym_dir, path)
path = os.path.expanduser(path)
if not os.path.exists(path):
raise FileNotFoundError(
f"NeMo Gym dataset not found at {path}. "
"Ensure the dataset file exists or run the appropriate "
"NeMo Gym dataset creation script."
)
LOG.info(
f"Loading NeMo Gym dataset from {path} (default server: {default_server})"
)
with open(path, encoding="utf-8") as f:
lines = f.readlines()
if max_samples and len(lines) > max_samples:
lines = random.sample(lines, max_samples) # nosec B311
for line in lines:
data = json.loads(line)
# Extract user prompt from the input messages
inputs = data.get("responses_create_params", {}).get("input", [])
task_prompt = ""
for inp in inputs:
if isinstance(inp, dict) and inp.get("role") in ("user",):
task_prompt = inp.get("content", "")
break
if not task_prompt and inputs:
# Fallback: use the last input's content
task_prompt = (
inputs[-1].get("content", "")
if isinstance(inputs[-1], dict)
else ""
)
# Per-row agent routing: agent_ref.name can override dataset-level server_name.
# NeMo Gym datasets may use agent names (e.g., "reasoning_gym_simple_agent")
# which differ from resource server names (e.g., "reasoning_gym").
# The dataset-level server_name is always the fallback.
row_agent_ref = data.get("agent_ref", {})
server_name = default_server
if row_agent_ref and row_agent_ref.get("name"):
                # Use the per-row name only when it refers to a resource server;
                # agent refs (type == "responses_api_agents") keep the dataset-level fallback.
row_name = row_agent_ref["name"]
if row_agent_ref.get("type") != "responses_api_agents":
# Not an agent — could be a direct resource server reference
server_name = row_name
all_examples.append(
{
"prompt": [{"role": "user", "content": task_prompt}],
"resources_server_ref": {"name": server_name},
"verify_extra": data,
}
)
random.shuffle(all_examples)
# Log environment distribution
env_counts: dict[str, int] = {}
for ex in all_examples:
name = ex["resources_server_ref"]["name"]
env_counts[name] = env_counts.get(name, 0) + 1
LOG.info(f"Loaded {len(all_examples)} NeMo Gym examples: {env_counts}")
return Dataset.from_list(all_examples)

View File

@@ -0,0 +1,64 @@
# Axolotl + NeMo Gym: Multi-Environment RL Training Example
#
# Trains on multiple NeMo Gym environments simultaneously:
# - Mini Sudoku (reasoning_gym)
# - Instruction Following
#
# Prerequisites:
# - uv (https://github.com/astral-sh/uv) installed
# - Download instruction_following.jsonl from nvidia/Nemotron-RL-instruction_following
# and place it at ~/Gym/resources_servers/instruction_following/data/instruction_following.jsonl
#
# Usage:
# axolotl train examples/nemo_gym_multi_env.yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM
sequence_len: 4096
load_in_4bit: false
# RL configuration
rl: grpo
chat_template: tokenizer_default
# GRPO / TRL settings
trl:
use_vllm: false
num_generations: 8
max_completion_length: 2048
# NeMo Gym plugin
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_config_paths:
- resources_servers/reasoning_gym/configs/resources_only.yaml
- resources_servers/instruction_following/configs/instruction_following.yaml
nemo_gym_datasets:
- path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
server_name: reasoning_gym
max_samples: 1000
- path: resources_servers/instruction_following/data/instruction_following.jsonl
server_name: instruction_following
max_samples: 1000
# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit
num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64
logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_multi_env

View File

@@ -0,0 +1,79 @@
# Axolotl + NeMo Gym: Multi-Turn Tool-Use RL Training Example
#
# This config trains a model with multi-turn reinforcement learning
# using NeMo Gym's agentic environments. The model generates responses,
# executes tool calls against NeMo Gym resource servers, and receives
# environment feedback across multiple turns.
#
# Multi-turn mode uses TRL's rollout_func with env_mask to ensure only
# model-generated tokens contribute to the training loss.
#
# Prerequisites:
# - uv (https://github.com/astral-sh/uv) installed
# - vLLM must be enabled (required for generate_rollout_completions)
#
# Usage:
# # Terminal 1: Start vLLM server
# trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct
#
# # Terminal 2: Train
# axolotl train examples/nemo_gym_multi_turn.yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM
sequence_len: 4096
load_in_4bit: false
# RL configuration
rl: grpo
chat_template: tokenizer_default
# GRPO / TRL settings — vLLM required for multi-turn
trl:
use_vllm: true
vllm_mode: server
vllm_server_host: localhost
vllm_server_port: 8000
num_generations: 4
max_completion_length: 2048
# NeMo Gym plugin with multi-turn enabled
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_head_port: 11000
nemo_gym_server_timeout: 360
nemo_gym_verify_timeout: 30
# Multi-turn settings
nemo_gym_multi_turn: true
nemo_gym_max_turns: 10
# Resource server configs — use environments that support tool calls
nemo_gym_config_paths:
- resources_servers/reasoning_gym/configs/resources_only.yaml
nemo_gym_datasets:
- path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
server_name: reasoning_gym
max_samples: 2000
# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit
num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64
logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_multi_turn

View File

@@ -0,0 +1,62 @@
# Axolotl + NeMo Gym: Sudoku RL Training Example
#
# This config trains a model to solve 4x4 Mini Sudoku puzzles using GRPO
# with NeMo Gym's reasoning_gym environment as the reward signal.
#
# Prerequisites:
# - uv (https://github.com/astral-sh/uv) installed
# - Git installed
# - The NeMo Gym repo will be auto-cloned to ~/Gym on first run
#
# Usage:
# axolotl train examples/nemo_gym_sudoku.yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM
sequence_len: 4096
load_in_4bit: false
# RL configuration
rl: grpo
chat_template: tokenizer_default
# GRPO / TRL settings
trl:
use_vllm: false
num_generations: 8
max_completion_length: 2048
# NeMo Gym plugin
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_head_port: 11000
nemo_gym_server_timeout: 360
nemo_gym_verify_timeout: 30
nemo_gym_config_paths:
- resources_servers/reasoning_gym/configs/resources_only.yaml
nemo_gym_datasets:
- path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
server_name: reasoning_gym
max_samples: 2000
# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit
num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64
logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_sudoku

View File

@@ -0,0 +1,329 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
Multi-turn rollout function for NeMo Gym environments.
Delegates multi-turn orchestration to NeMo Gym's agent servers via the /run
endpoint. The agent handles generation (by calling our vLLM server), tool
execution, session management, and reward computation.
This follows the same pattern as TRL's reference implementation at
examples/scripts/nemo_gym/train_multi_environment.py.
Architecture:
rollout_func(prompts, trainer)
-> expand prompts by num_generations
-> async POST /run to agent servers (one per sample)
-> parse response: prompt_ids, completion_ids, logprobs, env_mask, reward
-> return to TRL for GRPO training
"""
from __future__ import annotations
import asyncio
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def create_nemo_gym_rollout_func(
agent_servers: dict[str, str],
dataset_lookup: dict[int, dict],
request_timeout: float = 10800,
):
"""Create a TRL-compatible rollout_func that delegates to NeMo Gym agents.
Args:
agent_servers: Mapping of agent_name → agent URL (e.g., {"simple_agent": "http://host:port"}).
        dataset_lookup: Mapping of prompt text (raw user message content) → full JSONL row dict.
request_timeout: HTTP timeout for /run requests.
Returns:
A rollout_func with signature (prompts: list[str], trainer) -> dict.
"""
def rollout_func(prompts: list[str], trainer) -> dict[str, Any]:
is_training = trainer.model.training
num_generations = (
trainer.num_generations
if is_training
else getattr(trainer, "num_generations_eval", 1)
)
temperature = trainer.temperature
top_p = getattr(trainer, "top_p", None) or 0.999
max_completion_length = trainer.max_completion_length
eos_token_id = trainer.processing_class.eos_token_id
# Expand prompts: each prompt index repeated num_generations times
expanded_items = []
expanded_prompt_indices = []
for prompt_str in prompts:
# Prompts from TRL are chat-templated strings. Find the dataset item
# by matching against dataset_lookup keys (raw user message text).
            full_item = dataset_lookup.get(prompt_str)
if full_item is None:
full_item = {
"responses_create_params": {
"input": [{"role": "user", "content": prompt_str}]
}
}
for _ in range(num_generations):
# Preserve agent_ref for routing in _call_agents
dispatched: dict = full_item.get("verify_extra", full_item) # type: ignore[assignment]
if isinstance(dispatched, dict) and "agent_ref" not in dispatched:
agent_ref = full_item.get("agent_ref")
if agent_ref:
dispatched = {**dispatched, "agent_ref": agent_ref}
expanded_items.append(dispatched)
expanded_prompt_indices.append(prompt_str)
# Call NeMo Gym agents
loop = asyncio.new_event_loop()
try:
responses = loop.run_until_complete(
_call_agents(
dataset_items=expanded_items,
agent_servers=agent_servers,
timeout=request_timeout,
max_completion_length=max_completion_length,
temperature=temperature,
top_p=top_p,
)
)
finally:
loop.close()
# Parse responses into rollout format
all_prompt_ids = []
all_completion_ids = []
all_env_masks = []
all_logprobs = []
all_rewards = []
all_num_turns = []
        for response in responses:
result = _parse_agent_response(response, eos_token_id)
all_prompt_ids.append(result["prompt_ids"])
all_completion_ids.append(result["completion_ids"])
all_env_masks.append(result["env_mask"])
all_logprobs.append(result["logprobs"])
all_rewards.append(result["reward"])
all_num_turns.append(result["num_turns"])
# TRL expects prompt_ids to be unique (one per original prompt, not per generation)
unique_prompt_ids = all_prompt_ids[::num_generations]
# Wrap logprobs for TRL: list[list[list[float]]]
def _normalize(lp):
while isinstance(lp, (list, tuple)) and len(lp) > 0:
lp = lp[0]
return float(lp) if lp is not None else 0.0
wrapped_logprobs = [[[_normalize(lp)] for lp in seq] for seq in all_logprobs]
return {
"prompt_ids": unique_prompt_ids,
"completion_ids": all_completion_ids,
"env_mask": all_env_masks,
"logprobs": wrapped_logprobs,
"logprob_token_ids": None, # nosec B105
"env_reward": all_rewards,
"num_turns": all_num_turns,
}
return rollout_func
async def _call_agents(
dataset_items: list[dict],
agent_servers: dict[str, str],
timeout: float,
max_completion_length: int = 4096,
temperature: float = 1.0,
top_p: float = 0.999,
) -> list[dict]:
"""Async batch POST to NeMo Gym agent /run endpoints."""
import aiohttp
results = []
connector = aiohttp.TCPConnector(limit_per_host=64, limit=256)
# Use sock_read for per-request timeout (not total session timeout).
# This ensures a single stuck generation doesn't block all other requests.
client_timeout = aiohttp.ClientTimeout(total=None, sock_read=timeout)
async with aiohttp.ClientSession(
connector=connector, timeout=client_timeout, cookie_jar=aiohttp.DummyCookieJar()
) as session:
tasks = []
for item in dataset_items:
agent_ref = item.get("agent_ref", {})
agent_name = agent_ref.get("name", "")
agent_url = agent_servers.get(agent_name, "")
if not agent_url:
# Fallback: try first available agent
if agent_servers:
agent_url = next(iter(agent_servers.values()))
else:
results.append(
{
"response": {"output": []},
"reward": 0.0,
"error": "No agent server",
}
)
continue
# Build request body
request_body = dict(item)
params = request_body.setdefault("responses_create_params", {})
params.setdefault("max_output_tokens", max_completion_length)
params["temperature"] = temperature
params["top_p"] = top_p
tasks.append(_post_run(session, agent_url, request_body))
if tasks:
responses = await asyncio.gather(*tasks, return_exceptions=True)
for resp in responses:
if isinstance(resp, BaseException):
LOG.warning(f"Agent /run failed: {resp}")
results.append(
{"response": {"output": []}, "reward": 0.0, "error": str(resp)}
)
else:
results.append(resp)
return results
async def _post_run(session, agent_url: str, body: dict) -> dict:
"""POST to agent /run endpoint."""
async with session.post(f"{agent_url}/run", json=body) as resp:
if resp.status == 200:
return await resp.json()
text = await resp.text()
return {
"response": {"output": []},
"reward": 0.0,
"error": f"HTTP {resp.status}: {text[:200]}",
}
def _parse_agent_response(response: dict, eos_token_id: int) -> dict:
"""Parse NeMo Gym agent /run response into rollout format.
The agent returns:
response.output[]: list of turns, each with prompt_token_ids,
generation_token_ids, generation_log_probs
reward: float
"""
# Defaults for failed/empty responses
defaults = {
"prompt_ids": [eos_token_id],
"completion_ids": [eos_token_id],
"env_mask": [0],
"logprobs": [0.0],
"reward": 0.0,
"num_turns": 0,
}
if not isinstance(response, dict) or "error" in response:
return defaults
output_items = response.get("response", {}).get("output", [])
reward = float(response.get("reward", 0.0))
if not output_items:
defaults["reward"] = reward
return defaults
# Check at least one valid output
has_valid = False
for item in output_items:
if item.get("type") == "function_call":
has_valid = True
break
if item.get("type") == "message":
for c in item.get("content", []):
if (
isinstance(c, dict)
and c.get("type") == "output_text"
and c.get("text", "").strip()
):
has_valid = True
break
if has_valid:
break
if not has_valid:
defaults["reward"] = reward
return defaults
# Extract multi-turn token sequences
first_prompt_ids = None
completion_ids = []
env_mask = []
logprobs = []
seen_token_ids = []
num_turns = 0
for item in output_items:
prompt_token_ids = item.get("prompt_token_ids", [])
generation_token_ids = item.get("generation_token_ids", [])
generation_log_probs = item.get("generation_log_probs", [])
if not generation_token_ids:
continue
num_turns += 1
# First turn: capture prompt
if first_prompt_ids is None:
first_prompt_ids = list(prompt_token_ids)
seen_token_ids = list(prompt_token_ids)
else:
# Subsequent turns: extract tool result tokens (between turns)
if len(prompt_token_ids) > len(seen_token_ids):
tool_result_tokens = prompt_token_ids[len(seen_token_ids) :]
# Tool result tokens are NOT trained on (env_mask = 0)
completion_ids.extend(tool_result_tokens)
env_mask.extend([0] * len(tool_result_tokens))
logprobs.extend([0.0] * len(tool_result_tokens))
# Add generation tokens (trained on, env_mask = 1)
completion_ids.extend(generation_token_ids)
env_mask.extend([1] * len(generation_token_ids))
# Pad logprobs if shorter than generation tokens
gen_logprobs = list(generation_log_probs) if generation_log_probs else []
if len(gen_logprobs) < len(generation_token_ids):
gen_logprobs.extend([0.0] * (len(generation_token_ids) - len(gen_logprobs)))
logprobs.extend(gen_logprobs[: len(generation_token_ids)])
# Update seen tokens
seen_token_ids = list(prompt_token_ids) + list(generation_token_ids)
if first_prompt_ids is None:
return defaults
return {
"prompt_ids": first_prompt_ids,
"completion_ids": completion_ids,
"env_mask": env_mask,
"logprobs": logprobs,
"reward": reward,
"num_turns": num_turns,
}

View File

@@ -0,0 +1,503 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
NeMo Gym Plugin for Axolotl.
Integrates NVIDIA NeMo Gym environments as reward sources for GRPO training.
Handles server lifecycle, dataset loading, and reward function wiring.
Supports two modes:
- Single-turn (default): reward_fn calls /verify after each generation
- Multi-turn (nemo_gym_multi_turn: true): rollout_func orchestrates
multi-step interactions with tool execution via resource servers
"""
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Union
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
LOG = get_logger(__name__)
class NemoGymPlugin(BasePlugin):
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
When enabled, this plugin:
1. Clones and sets up the NeMo Gym repo (if needed)
2. Starts NeMo Gym resource servers
3. Loads datasets from NeMo Gym JSONL files
4. For single-turn: creates a reward function calling /verify
5. For multi-turn: creates a rollout_func with tool execution and env_mask
"""
def __init__(self):
super().__init__()
self._gym_dir = None
self._global_config = None
self._verify_endpoints = None
self._server_base_urls = None
self._reward_fn = None
self._dataset_lookup = None
self._agent_servers = {}
def get_input_args(self):
return "axolotl.integrations.nemo_gym.NemoGymArgs"
def pre_model_load(self, cfg):
"""Apply monkeypatches before trainer creation."""
if not cfg.nemo_gym_enabled:
return
# Always skip NCCL communicator init in NeMo Gym mode.
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
trl_cfg = getattr(cfg, "trl", None)
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
self._patch_skip_nccl_init()
def _patch_skip_nccl_init(self):
"""Monkeypatch VLLMClient.init_communicator to no-op.
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
serve script). The NCCL communicator is not needed and fails with both
vLLM V1 engine and standard OpenAI server mode.
"""
try:
from trl.generation.vllm_client import VLLMClient
VLLMClient._original_init_communicator = VLLMClient.init_communicator
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
"Skipping NCCL init_communicator (LoRA sync mode)"
)
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
except Exception as exc:
LOG.warning(f"Failed to patch VLLMClient: {exc}")
def register(self, cfg):
if not cfg.get("nemo_gym_enabled"):
return
LOG.info("NeMo Gym integration enabled")
gym_dir = cfg.get("nemo_gym_dir") or os.path.expanduser("~/Gym")
auto_clone = cfg.get("nemo_gym_auto_clone", True)
auto_start = cfg.get("nemo_gym_auto_start", True)
head_port = cfg.get("nemo_gym_head_port", 11000)
server_timeout = cfg.get("nemo_gym_server_timeout", 360)
from .server import (
ensure_gym_repo,
ensure_gym_venv,
get_agent_servers,
get_server_base_url,
get_server_configs,
get_verify_endpoint,
start_servers,
wait_for_resource_servers,
)
self._gym_dir = ensure_gym_repo(gym_dir, auto_clone=auto_clone)
if auto_start:
config_paths = cfg.get("nemo_gym_config_paths", [])
ensure_gym_venv(self._gym_dir)
start_servers(
self._gym_dir,
config_paths,
head_port=head_port,
timeout=server_timeout,
)
self._global_config = get_server_configs(head_port=head_port)
wait_for_resource_servers(self._global_config, timeout=server_timeout)
# Build endpoint maps for resource servers (/verify)
self._verify_endpoints = {}
self._server_base_urls = {}
for server_name in self._global_config:
try:
self._verify_endpoints[server_name] = get_verify_endpoint(
self._global_config, server_name
)
self._server_base_urls[server_name] = get_server_base_url(
self._global_config, server_name
)
except (ValueError, KeyError, TypeError):
pass
# Discover agent servers (/run) for multi-turn
self._agent_servers = get_agent_servers(self._global_config)
# Pre-build dataset lookup for multi-turn (needs to happen at register time,
# not load_datasets, because load_datasets may not be called if axolotl config
# has its own datasets field)
if cfg.get("nemo_gym_multi_turn") and cfg.get("nemo_gym_datasets"):
from .dataset import load_nemo_gym_datasets
gym_dir = cfg.get("nemo_gym_dir") or os.path.expanduser("~/Gym")
dataset = load_nemo_gym_datasets(gym_dir, cfg["nemo_gym_datasets"])
self._dataset_lookup = {}
for i in range(len(dataset)):
row = dataset[i]
# Use last message content as key (matches data_producer lookup)
prompt_text = row["prompt"][-1]["content"]
self._dataset_lookup[prompt_text] = row
LOG.info(f"Built dataset lookup with {len(self._dataset_lookup)} entries")
multi_turn = cfg.get("nemo_gym_multi_turn", False)
LOG.info(
f"NeMo Gym ready with servers: {list(self._verify_endpoints.keys())} "
f"(multi_turn={'enabled' if multi_turn else 'disabled'})"
)
def load_datasets(self, cfg, preprocess=False) -> Union["TrainDatasetMeta", None]:
if not cfg.nemo_gym_enabled:
return None
from axolotl.common.datasets import TrainDatasetMeta
from .dataset import load_nemo_gym_datasets
dataset_configs = cfg.nemo_gym_datasets
dataset = load_nemo_gym_datasets(self._gym_dir, dataset_configs)
# Build prompt → row lookup for multi-turn rollout_func
# (rollout_func only receives prompt text, needs to look up row data)
self._dataset_lookup = {}
for i in range(len(dataset)):
row = dataset[i]
# Use last message content as key (matches data_producer lookup)
prompt_text = row["prompt"][-1]["content"]
self._dataset_lookup[prompt_text] = row
return TrainDatasetMeta(
train_dataset=dataset,
eval_dataset=None,
total_num_steps=0, # computed later by the builder
)
def get_training_args(self, cfg):
"""Pass through vLLM settings and force async trainer for multi-turn."""
args = {}
# Pass vLLM settings from vllm config block to TRL training args
if cfg.vllm:
vllm_cfg = cfg.vllm
max_len = getattr(vllm_cfg, "max_model_len", None)
gpu_util = getattr(vllm_cfg, "gpu_memory_utilization", None)
tp_size = getattr(vllm_cfg, "tensor_parallel_size", None)
if max_len:
args["vllm_max_model_length"] = max_len
if gpu_util:
args["vllm_gpu_memory_utilization"] = gpu_util
if tp_size:
args["vllm_tensor_parallel_size"] = tp_size
# Force async trainer for multi-turn: NemoGymDataProducer needs the
# data producer protocol. Setting use_data_producer=True selects
# AxolotlAsyncGRPOTrainer which supports _create_data_producer().
# With async_prefetch=False this runs synchronously — no threading.
if cfg.nemo_gym_multi_turn and self._agent_servers:
args["use_data_producer"] = True
LOG.info(
"NeMo Gym multi-turn: forcing use_data_producer=True for data producer protocol"
)
# Dataloader workers fork subprocesses that can't handle the async
# HTTP connections to NeMo Gym agents. Force num_workers=0.
if getattr(cfg, "dataloader_num_workers", None) not in (None, 0):
LOG.warning(
f"NeMo Gym: overriding dataloader_num_workers={cfg.dataloader_num_workers} → 0 "
"(forked workers can't use NeMo Gym agent connections)"
)
cfg.dataloader_num_workers = 0
if args:
LOG.info(f"NeMo Gym plugin injecting training args: {args}")
return args if args else None
def post_trainer_create(self, cfg, trainer):
"""Wire NeMo Gym into the trainer (reward_fn or rollout_func)."""
if not cfg.nemo_gym_enabled:
return
model_name = cfg.nemo_gym_model_name or cfg.base_model or "axolotl-model"
verify_timeout = cfg.nemo_gym_verify_timeout or 30
multi_turn = cfg.nemo_gym_multi_turn or False
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
# - Install LoRA sync (when vllm_lora_sync=True)
# - Or no-op sync_weights (when using standard vLLM server)
trl_cfg = getattr(cfg, "trl", None)
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
vllm_gen = trainer.vllm_generation
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
self._setup_lora_sync(trainer)
# Verify the vLLM server supports runtime LoRA loading
self._check_lora_endpoint(vllm_gen)
else:
# No NCCL, no LoRA sync — skip all weight sync paths
vllm_gen.sync_weights = lambda: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
)
type(vllm_gen).sync_weights = lambda self: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
)
# Also patch the async trainer's internal sync method
if hasattr(trainer, "_maybe_sync_vllm_weights"):
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
"Async weight sync skipped (NeMo Gym mode)"
)
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
if multi_turn:
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
else:
self._wire_single_turn(trainer, model_name, verify_timeout)
def _wire_single_turn(self, trainer, model_name, verify_timeout):
"""Inject single-turn reward function into the trainer."""
from .rewards import create_nemo_gym_reward_fn
self._reward_fn = create_nemo_gym_reward_fn(
global_config=self._global_config,
verify_endpoints=self._verify_endpoints,
model_name=model_name,
verify_timeout=verify_timeout,
)
if hasattr(trainer, "reward_funcs"):
trainer.reward_funcs.append(self._reward_fn)
trainer.reward_func_names.append("nemo_gym")
trainer.reward_processing_classes.append(None)
LOG.info(
f"Added NeMo Gym reward function (single-turn). "
f"Total reward functions: {len(trainer.reward_funcs)}"
)
else:
LOG.warning(
"Trainer does not have reward_funcs attribute. "
"NeMo Gym reward function not injected. "
"Ensure you are using a GRPO trainer."
)
def _wire_multi_turn(self, cfg, trainer, model_name, verify_timeout):
"""Replace the data producer with NemoGymDataProducer.
The plugin forces use_data_producer=True (in get_training_args) which
selects AxolotlAsyncGRPOTrainer. Here we swap its data_producer with
our NemoGymDataProducer that calls agent /run instead of vLLM generate.
"""
if not self._agent_servers:
LOG.warning(
"No NeMo Gym agent servers discovered. Multi-turn requires agent servers "
"started via ng_run with an agent config. Falling back to single-turn."
)
self._wire_single_turn(trainer, model_name, verify_timeout)
return
if not hasattr(trainer, "data_producer") or trainer.data_producer is None:
LOG.warning(
"Trainer has no data_producer. NeMo Gym multi-turn requires "
"use_data_producer=true (should be auto-set by plugin)."
)
return
from axolotl.core.trainers.grpo.async_trainer import AsyncDataProducer
from .data_producer import NemoGymDataProducer
# Get the current producer's config and params
current = trainer.data_producer
# Unwrap AsyncDataProducer to get the inner producer's config
if isinstance(current, AsyncDataProducer):
inner = current._inner
else:
inner = current
nemo_producer = NemoGymDataProducer(
config=inner.config,
prompt_dataset=inner._dataset,
num_generations=inner._num_generations,
generation_batch_size=inner._generation_batch_size,
train_batch_size=inner._train_batch_size,
steps_per_generation=inner._steps_per_generation,
shuffle_dataset=inner._shuffle_dataset,
seed=inner._seed,
agent_servers=self._agent_servers,
dataset_lookup=self._dataset_lookup or {},
request_timeout=float(cfg.nemo_gym_run_timeout or 300),
)
nemo_producer.set_trainer(trainer)
# Re-wrap in AsyncDataProducer if async prefetch is enabled
if getattr(trainer.args, "async_prefetch", False):
nemo_producer = AsyncDataProducer(
nemo_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
trainer.data_producer = nemo_producer
LOG.info(
f"NeMo Gym data producer installed "
f"(agent servers: {list(self._agent_servers.keys())}, "
f"async={'yes' if getattr(trainer.args, 'async_prefetch', False) else 'no'})"
)
# Passthrough reward function — agent /run already computed rewards
from .rewards import reward_env
if hasattr(trainer, "reward_funcs"):
trainer.reward_funcs.append(reward_env)
trainer.reward_func_names.append("nemo_gym")
trainer.reward_processing_classes.append(None)
@staticmethod
def _check_lora_endpoint(vllm_gen):
"""Verify the vLLM server supports runtime LoRA loading."""
import requests as http_requests
if not hasattr(vllm_gen, "vllm_client") or vllm_gen.vllm_client is None:
return # Non-main rank in multi-GPU — client only exists on rank 0
base_url = vllm_gen.vllm_client.base_url
try:
# Send a dummy load request — if the endpoint exists, we get a
# proper error (400/404 about the adapter), not a route 404.
resp = http_requests.post(
f"{base_url}/v1/load_lora_adapter",
json={"lora_name": "__probe__", "lora_path": "/nonexistent"},
timeout=5,
)
if (
resp.status_code == 404
and "Not Found" in resp.text
and "adapter" not in resp.text.lower()
):
LOG.warning(
"vLLM server does not expose /v1/load_lora_adapter. "
"Set VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 when starting vLLM, e.g.:\n"
" VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 python -m vllm.entrypoints.openai.api_server "
"--enable-lora --max-lora-rank 64 ..."
)
except Exception:
pass # Server might not be up yet, sync will warn later
def _setup_lora_sync(self, trainer):
"""Replace sync_weights with LoRA adapter sync via filesystem + HTTP.
If the async trainer is detected (has ``_sync_lora_adapter``), delegates
to it — that method already handles multi-GPU (FSDP/DeepSpeed state_dict
gather, broadcast sync dir, barrier).
Otherwise installs a standalone closure for the non-async GRPO path that
saves the adapter and POSTs to ``/v1/load_lora_adapter``.
"""
vllm_gen = trainer.vllm_generation
# Async trainer path: delegate to its _sync_lora_adapter (multi-GPU safe)
if hasattr(trainer, "_sync_lora_adapter"):
def lora_sync_weights():
trainer._sync_lora_adapter()
vllm_gen.sync_weights = lora_sync_weights
type(vllm_gen).sync_weights = lambda self: lora_sync_weights()
LOG.info(
"Installed LoRA adapter sync "
"(delegates to async trainer._sync_lora_adapter)"
)
return
# Non-async standard GRPO path: standalone closure
import os
import shutil
import tempfile
import requests as http_requests
base_model = getattr(trainer.args, "model_name_or_path", None) or "axolotl-lora"
sync_state = {"version": 0, "sync_dir": tempfile.mkdtemp(prefix="lora_sync_")}
def lora_sync_weights():
"""Save LoRA adapter and load it into vLLM."""
accelerator = vllm_gen.accelerator
model = vllm_gen.model
if vllm_gen.mode != "server":
return
sync_state["version"] += 1
version = sync_state["version"]
adapter_path = os.path.join(sync_state["sync_dir"], f"v{version}")
wrapped_model = getattr(trainer, "model_wrapped", model)
state_dict = accelerator.get_state_dict(wrapped_model)
if accelerator.is_main_process:
unwrapped = accelerator.unwrap_model(model)
unwrapped.save_pretrained(adapter_path, state_dict=state_dict)
base_url = vllm_gen.vllm_client.base_url
resp = http_requests.post(
f"{base_url}/v1/load_lora_adapter",
json={
"lora_name": base_model,
"lora_path": adapter_path,
"load_inplace": True,
},
timeout=30,
)
if resp.status_code != 200:
resp = http_requests.post(
f"{base_url}/set_lora_adapter/",
json={
"lora_name": "active_lora",
"lora_int_id": version,
"lora_path": adapter_path,
},
timeout=30,
)
if resp.status_code != 200:
LOG.warning(
f"Failed to set LoRA adapter: "
f"{resp.status_code} {resp.text}"
)
return
try:
vllm_gen.vllm_client.reset_prefix_cache()
except Exception as exc:
LOG.warning("Failed to reset prefix cache: %s", exc)
if version > 1:
old = os.path.join(sync_state["sync_dir"], f"v{version - 1}")
if os.path.exists(old):
shutil.rmtree(old, ignore_errors=True)
LOG.info(f"Synced LoRA adapter v{version} to vLLM ({adapter_path})")
if accelerator.num_processes > 1:
import torch.distributed as dist
if dist.is_initialized():
dist.barrier()
vllm_gen.sync_weights = lora_sync_weights
type(vllm_gen).sync_weights = lambda self: lora_sync_weights()
LOG.info("Installed LoRA adapter sync (standalone fallback)")
def post_train_unload(self, cfg):
"""Cleanup NeMo Gym servers if we started them."""
if cfg.get("nemo_gym_enabled") and cfg.get("nemo_gym_auto_start", True):
from .server import _cleanup_servers
_cleanup_servers()
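The hooks above fire in a fixed order during a training run. A minimal sketch of that flow, assuming the register(cfg) hook shown earlier and a hypothetical build_grpo_trainer helper standing in for axolotl's trainer builder:
plugin = NemoGymPlugin()
plugin.register(cfg)                       # clone/start servers, discover endpoints
meta = plugin.load_datasets(cfg)           # TrainDatasetMeta or None
extra_args = plugin.get_training_args(cfg) or {}
trainer = build_grpo_trainer(cfg, **extra_args)  # hypothetical builder
plugin.post_trainer_create(cfg, trainer)   # wire reward_fn or data producer
plugin.post_train_unload(cfg)              # terminate servers we started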

View File

@@ -0,0 +1,274 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
NeMo Gym reward functions.
Provides ready-to-use reward functions for axolotl configs::
trl:
reward_funcs:
# Multi-turn: passthrough reward from agent /run
- axolotl.integrations.nemo_gym.rewards.reward_env
# Single-turn: call /verify endpoints directly
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
"""
from __future__ import annotations
from typing import Any
import threading
import numpy as np
import requests
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# ---------------------------------------------------------------------------
# Multi-turn passthrough reward
# ---------------------------------------------------------------------------
def reward_env(completions, prompts=None, **kwargs):
"""Passthrough: extract pre-computed reward from NeMo Gym agent /run response.
The ``NemoGymDataProducer`` injects ``env_reward`` into each sample's
kwargs after the agent returns from ``/run``. This function simply
forwards that value so TRL can log it alongside other reward signals.
Use this in your config when ``nemo_gym_multi_turn: true``::
trl:
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_env
"""
env_rewards = kwargs.get("env_reward")
if env_rewards is not None:
if isinstance(env_rewards, (list, tuple)):
return [float(r) for r in env_rewards]
return [float(env_rewards) for _ in completions]
return [0.0 for _ in completions]
# ---------------------------------------------------------------------------
# Single-turn /verify reward
# ---------------------------------------------------------------------------
# Module-level cache for discovered verify URLs
_verify_urls: dict[str, str] = {}
_verify_urls_lock = threading.Lock()
def _get_verify_urls(head_port: int = 11000) -> dict[str, str]:
"""Discover verify endpoints from the NeMo Gym head server.
Results are cached so that the HTTP round-trip only happens once per
process. A lock guards against concurrent discovery from multiple
threads (e.g. async_prefetch background thread + main training thread).
"""
global _verify_urls
if _verify_urls:
return _verify_urls
with _verify_urls_lock:
# Double-check after acquiring lock
if _verify_urls:
return _verify_urls
import yaml
try:
resp = requests.get(
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
)
config = yaml.safe_load(resp.text)
if isinstance(config, str):
config = yaml.safe_load(config)
for _name, cfg in config.items():
if not isinstance(cfg, dict):
continue
for srv_name, srv_cfg in cfg.get("resources_servers", {}).items():
if (
isinstance(srv_cfg, dict)
and "host" in srv_cfg
and "port" in srv_cfg
):
_verify_urls[srv_name] = (
f"http://{srv_cfg['host']}:{srv_cfg['port']}/verify"
)
except Exception as exc:
LOG.warning(f"Failed to discover NeMo Gym verify endpoints: {exc}")
return _verify_urls
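# Example discovered mapping (server name and port are illustrative):
#   {"math_server": "http://127.0.0.1:11001/verify"}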
def reward_nemo_gym_verify(completions, prompts=None, **kwargs):
"""Call NeMo Gym ``/verify`` endpoint for each completion (single-turn).
Requires ``resources_server_ref`` and ``verify_extra`` kwargs, which the
NeMo Gym dataset loader injects automatically.
Use this in your config when ``nemo_gym_multi_turn: false``::
trl:
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
"""
verify_urls = _get_verify_urls()
refs = kwargs.get("resources_server_ref", [])
extras = kwargs.get("verify_extra", [])
scores = []
for i, completion in enumerate(completions):
text = completion[0]["content"] if completion else ""
prompt = prompts[i][0]["content"] if prompts and i < len(prompts) else ""
srv_name = (
refs[i]["name"] if i < len(refs) and isinstance(refs[i], dict) else ""
)
url = verify_urls.get(srv_name, "")
if not url:
scores.append(0.0)
continue
extra = extras[i] if i < len(extras) else {}
req = {k: v for k, v in extra.items() if v is not None}
req["responses_create_params"] = {
"input": [{"role": "user", "content": prompt}]
}
req["response"] = {
"id": "resp",
"created_at": 0,
"model": "axolotl",
"object": "response",
"output": [
{
"id": "msg",
"role": "assistant",
"type": "message",
"status": "completed",
"content": [
{"type": "output_text", "text": text, "annotations": []}
],
}
],
"parallel_tool_calls": True,
"tool_choice": "auto",
"tools": [],
}
try:
resp = requests.post(url, json=req, timeout=30)
reward = resp.json().get("reward", 0.0) if resp.ok else 0.0
except Exception as exc:
LOG.warning(f"Verify request to {url} failed: {exc}")
reward = 0.0
scores.append(float(reward))
return scores
# ---------------------------------------------------------------------------
# Factory used internally by the plugin
# ---------------------------------------------------------------------------
def create_nemo_gym_reward_fn(
global_config: dict,
verify_endpoints: dict[str, str],
model_name: str = "axolotl-model",
verify_timeout: int = 30,
):
"""Create a reward function bound to specific verify endpoints.
Used internally by ``NemoGymPlugin._wire_single_turn()`` to inject a
reward function that already knows the endpoint map (no discovery needed).
"""
def reward_fn(
completions: list[list[dict[str, str]]],
prompts: list[list[dict[str, str]]] | None = None,
**kwargs: Any,
) -> np.ndarray:
resources_server_refs = kwargs.get("resources_server_ref", [])
verify_extras = kwargs.get("verify_extra", [])
scores = []
for i, completion in enumerate(completions):
completion_text = completion[0]["content"] if completion else ""
task_prompt = prompts[i][0]["content"] if prompts and i < len(prompts) else ""
server_name = (
resources_server_refs[i]["name"]
if i < len(resources_server_refs)
else None
)
if server_name is None or server_name not in verify_endpoints:
LOG.warning(
f"No verify endpoint for server '{server_name}', returning 0 reward"
)
scores.append(0.0)
continue
verify_endpoint = verify_endpoints[server_name]
verify_request = (
{k: v for k, v in verify_extras[i].items() if v is not None}
if i < len(verify_extras)
else {}
)
verify_request["responses_create_params"] = {
"input": [{"role": "user", "content": task_prompt}]
}
verify_request["response"] = {
"id": "resp",
"created_at": 0,
"model": model_name,
"object": "response",
"output": [
{
"id": "msg",
"role": "assistant",
"type": "message",
"status": "completed",
"content": [
{
"type": "output_text",
"text": completion_text,
"annotations": [],
}
],
}
],
"parallel_tool_calls": True,
"tool_choice": "auto",
"tools": [],
}
try:
resp = requests.post(
verify_endpoint, json=verify_request, timeout=verify_timeout
)
if resp.status_code == 200:
reward = resp.json().get("reward", 0.0)
else:
LOG.warning(
f"Verify request returned status {resp.status_code}: {resp.text[:200]}"
)
reward = 0.0
except requests.exceptions.RequestException as exc:
LOG.warning(f"Verify request failed: {exc}")
reward = 0.0
scores.append(float(reward))
return np.array(scores)
return reward_fn
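For reference, a minimal sketch of the call shape TRL uses for these reward functions, with synthetic data; the server name and verify_extra fields are hypothetical:
completions = [[{"role": "assistant", "content": "4"}]]
prompts = [[{"role": "user", "content": "What is 2 + 2?"}]]
scores = reward_nemo_gym_verify(
    completions,
    prompts=prompts,
    resources_server_ref=[{"name": "math_server"}],  # hypothetical server name
    verify_extra=[{"expected_answer": "4"}],         # hypothetical extra fields
)
# -> [1.0] if the math_server /verify endpoint scores the answer as correct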

View File

@@ -0,0 +1,242 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
"""
NeMo Gym server lifecycle management.
Handles cloning the NeMo Gym repo, starting resource servers,
waiting for readiness, and cleanup on exit.
"""
from __future__ import annotations
import atexit
import os
import subprocess # nosec B404
import time
import requests
import yaml
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
_ng_process = None
_ng_log_file = None
def ensure_gym_repo(gym_dir: str, auto_clone: bool = True) -> str:
"""Clone the NeMo Gym repo if it doesn't exist.
Args:
gym_dir: Path to the NeMo Gym directory.
auto_clone: Whether to auto-clone if missing.
Returns:
Resolved path to the NeMo Gym directory.
"""
gym_dir = os.path.expanduser(gym_dir)
if os.path.exists(gym_dir):
LOG.info(f"NeMo Gym directory exists at {gym_dir}")
return gym_dir
if not auto_clone:
raise FileNotFoundError(
f"NeMo Gym directory not found at {gym_dir} and auto_clone is disabled."
)
LOG.info(f"Cloning NeMo Gym to {gym_dir}...")
subprocess.run( # nosec
["git", "clone", "https://github.com/NVIDIA-NeMo/Gym.git", gym_dir],
check=True,
)
return gym_dir
def ensure_gym_venv(gym_dir: str):
"""Set up the NeMo Gym Python venv if not present."""
venv_python = os.path.join(gym_dir, ".venv", "bin", "python")
if os.path.exists(venv_python):
return
LOG.info("Setting up NeMo Gym venv...")
subprocess.run(["uv", "venv", "--python", "3.12"], cwd=gym_dir, check=True) # nosec
subprocess.run( # nosec
["bash", "-c", "source .venv/bin/activate && uv sync"],
cwd=gym_dir,
check=True,
)
def start_servers(
gym_dir: str,
config_paths: list[str],
head_port: int = 11000,
timeout: int = 360,
):
"""Start NeMo Gym resource servers via ng_run.
Args:
gym_dir: Path to the NeMo Gym directory.
config_paths: List of config YAML paths relative to gym_dir.
head_port: Port for the head server.
timeout: Max seconds to wait for servers.
"""
global _ng_process, _ng_log_file
head_url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
# Check if already running
try:
requests.get(head_url, timeout=2)
LOG.info(f"NeMo Gym servers already running on port {head_port}.")
return
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
pass
ng_run_bin = os.path.join(gym_dir, ".venv", "bin", "ng_run")
config_arg = f"+config_paths=[{','.join(config_paths)}]"
_ng_log_file = open(os.path.join(gym_dir, "ng_run.log"), "w") # noqa: SIM115
_ng_process = subprocess.Popen( # nosec B603
[ng_run_bin, config_arg, "+skip_venv_if_present=true"],
cwd=gym_dir,
stdout=_ng_log_file,
stderr=subprocess.STDOUT,
)
atexit.register(_cleanup_servers)
LOG.info("Waiting for NeMo Gym head server...")
for _ in range(timeout // 3):
try:
requests.get(head_url, timeout=2)
LOG.info("NeMo Gym head server is ready.")
return
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
if _ng_process.poll() is not None:
raise RuntimeError(
"NeMo Gym server process exited unexpectedly. "
f"Check {gym_dir}/ng_run.log for details."
) from None
time.sleep(3)
raise RuntimeError(
f"NeMo Gym servers did not start within {timeout}s. "
f"Check {gym_dir}/ng_run.log for details."
)
def get_server_configs(head_port: int = 11000) -> dict:
"""Fetch the global config from the NeMo Gym head server.
Returns:
Dict mapping server_name -> server config.
"""
response = requests.get(
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
)
response.raise_for_status()
result = yaml.safe_load(response.text)
# NeMo Gym head server double-encodes: YAML string inside a YAML string
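# e.g. the head server may return the quoted scalar '"a: 1\n"', which
# safe_load first yields as the string 'a: 1\n'; loading again gives {'a': 1}.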
if isinstance(result, str):
result = yaml.safe_load(result)
return result
def get_agent_servers(
global_config: dict, head_host: str = "127.0.0.1"
) -> dict[str, str]:
"""Discover NeMo Gym agent servers from the global config.
Agent servers handle multi-turn orchestration via /run endpoint.
Returns mapping of agent_name → URL (e.g., {"simple_agent": "http://host:port"}).
"""
agents = {}
for top_name, top_cfg in global_config.items():
if not isinstance(top_cfg, dict):
continue
agent_dict = top_cfg.get("responses_api_agents", {})
if not agent_dict:
continue
for _agent_name, agent_cfg in agent_dict.items():
if not isinstance(agent_cfg, dict):
continue
host = agent_cfg.get("host", "127.0.0.1")
port = agent_cfg.get("port")
if not port:
continue
# Replace loopback with head_host for remote access
host = _normalize_host(host, fallback=head_host)
# Use the top-level config name (not the inner agent name)
# because dataset agent_ref.name references the top-level name
agents[top_name] = f"http://{host}:{port}"
if agents:
LOG.info(f"Discovered NeMo Gym agent servers: {agents}")
return agents
def _normalize_host(host: str, fallback: str = "127.0.0.1") -> str:
"""Normalize bind-all and loopback addresses for reachability."""
if host in ("0.0.0.0", "localhost"): # nosec B104
return fallback
return host
def get_server_base_url(global_config: dict, server_name: str) -> str:
"""Get the base URL for a given resource server."""
try:
srv_cfg = global_config[server_name]["resources_servers"][server_name]
host = _normalize_host(srv_cfg["host"])
return f"http://{host}:{srv_cfg['port']}"
except (KeyError, TypeError) as exc:
raise ValueError(
f"Could not find resource server config for '{server_name}' in NeMo Gym. "
f"Available servers: {list(global_config.keys())}"
) from exc
def get_verify_endpoint(global_config: dict, server_name: str) -> str:
"""Get the /verify endpoint URL for a given resource server."""
return f"{get_server_base_url(global_config, server_name)}/verify"
def wait_for_resource_servers(global_config: dict, timeout: int = 180):
"""Wait for all resource servers in the config to become reachable."""
for srv_name in global_config:
try:
srv_cfg = global_config[srv_name]["resources_servers"][srv_name]
except (KeyError, TypeError):
continue # Skip non-server config entries silently
host, port = _normalize_host(srv_cfg["host"]), srv_cfg["port"]
LOG.info(f"Waiting for resource server '{srv_name}' at {host}:{port}...")
for _ in range(timeout // 2):
try:
requests.get(f"http://{host}:{port}/", timeout=2)
LOG.info(f"Resource server '{srv_name}' is ready.")
break
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
time.sleep(2)
else:
raise RuntimeError(
f"Resource server '{srv_name}' at {host}:{port} "
f"did not start within {timeout}s."
)
def _cleanup_servers():
"""Terminate NeMo Gym server process on exit."""
global _ng_process, _ng_log_file
if _ng_process is not None and _ng_process.poll() is None:
LOG.info("Terminating NeMo Gym servers...")
_ng_process.terminate()
try:
_ng_process.wait(timeout=10)
except subprocess.TimeoutExpired:
_ng_process.kill()
if _ng_log_file is not None:
_ng_log_file.close()
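A minimal usage sketch of these helpers, assuming a hypothetical config path and server name:
from axolotl.integrations.nemo_gym.server import (
    ensure_gym_repo,
    get_server_configs,
    get_verify_endpoint,
    start_servers,
)
gym_dir = ensure_gym_repo("~/Gym")                 # clones if missing
start_servers(gym_dir, ["configs/math.yaml"])      # hypothetical config path
config = get_server_configs(head_port=11000)
verify_url = get_verify_endpoint(config, "math_server")  # hypothetical name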

View File

@@ -446,6 +446,114 @@ def main(script_args: ScriptArguments):
"logprob_token_ids": extract_logprobs(all_outputs)[1],
}
# --- OpenAI-compatible endpoints (for NeMo Gym agent integration) ---
@app.get("/v1/models")
async def list_models():
"""OpenAI-compatible models endpoint."""
return {
"object": "list",
"data": [
{"id": script_args.model, "object": "model", "owned_by": "axolotl"}
],
}
@app.post("/v1/chat/completions")
async def openai_chat_completions(request_body: dict):
"""OpenAI-compatible chat completions endpoint.
Translates OpenAI format to our internal /chat/ format so NeMo Gym's
model server proxy can call us directly.
"""
messages_list = request_body.get("messages", [])
temperature = request_body.get("temperature", 1.0)
max_tokens = request_body.get("max_tokens", 512)
top_p = request_body.get("top_p", 1.0)
n = request_body.get("n", 1)
generation_kwargs = {
"n": n,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"logprobs": 0, # Always return logprobs (NeMo Gym needs them)
}
sampling_params = SamplingParams(
**{k: v for k, v in generation_kwargs.items() if v is not None}
)
# Send to vLLM worker
chunked = chunk_list([messages_list], script_args.data_parallel_size)
for conn, chunk in zip(connections, chunked, strict=True):
if not chunk:
chunk = [[{"role": "user", "content": "<placeholder>"}]]
kwargs = {
"messages": chunk,
"sampling_params": sampling_params,
"use_tqdm": False,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs))
if not all_outputs:
return {"choices": [], "model": script_args.model}
# Format as OpenAI response
import uuid
choices = []
for i, output in enumerate(all_outputs):
for j, out in enumerate(output.outputs):
text = out.text
# Build logprobs in OpenAI format
lp_list = None
if out.logprobs:
lp_list = {
"content": [
{"token": "", "logprob": next(iter(lp.values())).logprob} # nosec B105
for lp in out.logprobs
]
}
choice = {
"index": i * n + j,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop"
if out.finish_reason == "stop"
else "length",
"logprobs": lp_list,
}
# Include token ID information for NeMo Gym
choice["prompt_token_ids"] = output.prompt_token_ids
choice["generation_token_ids"] = list(out.token_ids)
if out.logprobs:
choice["generation_log_probs"] = [
next(iter(lp.values())).logprob for lp in out.logprobs
]
choices.append(choice)
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
completion_tokens = sum(
len(out.token_ids) for o in all_outputs for out in o.outputs
)
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"model": script_args.model,
"choices": choices,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
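# Illustrative exchange for the endpoint above:
#   POST /v1/chat/completions
#   {"messages": [{"role": "user", "content": "hi"}], "n": 1, "max_tokens": 8}
# The response carries the standard OpenAI fields plus, per choice,
# prompt_token_ids / generation_token_ids / generation_log_probs, so NeMo Gym
# agents can train on the exact sampled token ids instead of re-tokenizing.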
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
@app.post("/init_communicator/")