Nemo gym integration (#3516) [skip ci]
* nemo gym integration with grpo wip
* mostly working
* cleanup
* simplify
* update docs
* nemo gym support wip
* cleanup
* chore: lint
* address PR review and add more tests
* chore: lint
* post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI
* handle lora kernel auto-enable for moe without grouped_mm
* prefer not to import torch in schema validation
* address pr comments, add timeout, add tests
* roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* fix: robust handling of race condition on patching check (#3543) [skip ci]
* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixes
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address core review feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* fix for ebft sync and update docs
* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@@ -176,6 +176,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             )
         training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
         blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
+        if not async_grpo:
+            # Filter out async/fast-async-only fields not in standard GRPOConfig.
+            # These are defined in FastAsyncGRPOConfig and only used by
+            # AxolotlAsyncGRPOConfig. Standard GRPOConfig rejects them.
+            import dataclasses
+
+            from trl import GRPOConfig as _BaseGRPOConfig
+
+            from axolotl.core.trainers.grpo.fast_async_trainer import (
+                FastAsyncGRPOConfig,
+            )
+
+            async_only_fields = {
+                f.name for f in dataclasses.fields(FastAsyncGRPOConfig)
+            } - {f.name for f in dataclasses.fields(_BaseGRPOConfig)}
+            blocklist_args_kwargs.extend(list(async_only_fields))
         if self.cfg.rl is RLType.GDPO:
             training_args_kwargs.setdefault(
                 "multi_objective_aggregation", "normalize_then_sum"
@@ -34,7 +34,16 @@ class EBFTStrategy:
             return AxolotlStridedEBFTTrainer

         # Structured mode: async or sync
-        use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
+        # use_data_producer also triggers async trainer (needed for LoRA sync
+        # without async_prefetch, since sync trainer lacks LoRA sync support)
+        use_async = (
+            cfg
+            and cfg.trl
+            and (
+                getattr(cfg.trl, "async_prefetch", False)
+                or getattr(cfg.trl, "use_data_producer", False)
+            )
+        )
         if use_async:
             from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer

@@ -50,7 +59,14 @@ class EBFTStrategy:
             return AxolotlStridedEBFTConfig

         # Structured mode: async or sync config
-        use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False)
+        use_async = (
+            cfg
+            and cfg.trl
+            and (
+                getattr(cfg.trl, "async_prefetch", False)
+                or getattr(cfg.trl, "use_data_producer", False)
+            )
+        )
         if use_async:
             return AxolotlAsyncEBFTConfig
         return AxolotlEBFTConfig

@@ -1012,27 +1012,44 @@ class AsyncGRPOTrainer(GRPOTrainer):
         import requests

         vllm_client = self.vllm_generation.vllm_client
-        url = f"{vllm_client.base_url}/set_lora_adapter/"
+        base_url = vllm_client.base_url
+        base_model = getattr(self.args, "model_name_or_path", "axolotl-lora")
+        sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300
+
+        # Try standard vLLM /v1/load_lora_adapter first, fall back to custom endpoint
         response = requests.post(
-            url,
+            f"{base_url}/v1/load_lora_adapter",
             json={
-                "lora_name": "active_lora",
-                "lora_int_id": self._lora_sync_version,
+                "lora_name": base_model,
                 "lora_path": adapter_path,
+                "load_inplace": True,
             },
+            timeout=sync_timeout,
         )
         if response.status_code != 200:
-            logger.warning(
-                "Failed to set LoRA adapter: %s %s",
-                response.status_code,
-                response.text,
+            # Fallback: try custom /set_lora_adapter/ endpoint
+            response = requests.post(
+                f"{base_url}/set_lora_adapter/",
+                json={
+                    "lora_name": "active_lora",
+                    "lora_int_id": self._lora_sync_version,
+                    "lora_path": adapter_path,
+                },
+                timeout=30,
             )
-            return
+        if response.status_code != 200:
+            logger.warning(
+                "Failed to set LoRA adapter: %s %s",
+                response.status_code,
+                response.text,
+            )
+            return

         # Reset prefix cache after adapter update
-        vllm_client.reset_prefix_cache()
+        try:
+            vllm_client.reset_prefix_cache()
+        except Exception as exc:
+            logger.warning("Failed to reset prefix cache: %s", exc)

         # Clean up old adapter versions (keep only current)
         if self._lora_sync_version > 1:
@@ -2486,6 +2503,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
                         logits, completion_ids, self.temperature
                     )
                     all_logps.append(logps)
+                    # Liger fused path doesn't compute entropy — append zeros
+                    if compute_entropy:
+                        all_entropies.append(torch.zeros_like(logps))
                 else:
                     logits = logits[:, :-1, :]
                     logits = logits[:, -logits_to_keep:, :]

412  src/axolotl/integrations/nemo_gym/README.md  Normal file
@@ -0,0 +1,412 @@
# NeMo Gym Integration for Axolotl

Train LLMs with reinforcement learning using [NVIDIA NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) environments as reward sources. NeMo Gym provides 50+ verified RL environments spanning math, coding, tool-use, reasoning, and safety — each with deterministic reward signals.

## Validated Training Paths

| Path | Speed | Multi-turn | Architecture |
|------|-------|------------|--------------|
| **Async GRPO + Data Producer** | Fastest (3x) | Yes | `NemoGymDataProducer` replaces vLLM generation |
| Standard GRPO + Data Producer | Baseline | Yes | Same producer, no async prefetch |
| Standard GRPO + /verify | Simplest | No | Reward function calls /verify directly |
| FSDP2 + /verify (2 GPU) | Distributed | No | `fsdp_version: 2` |

Multi-turn uses `nemo_gym_multi_turn: true`, which auto-enables the async trainer's
data producer protocol. The plugin's `NemoGymDataProducer` calls NeMo Gym agent `/run`
endpoints and returns a `RolloutDataset` with proper IS correction, env_mask, and rewards.

All paths are tested end-to-end with Qwen3-0.6B + LoRA, logged to the wandb project `nemo-gym-rl`.

## Quick Start

### Prerequisites

- [uv](https://github.com/astral-sh/uv) package manager (for NeMo Gym's venv)
- Two GPUs recommended (one for the vLLM server, one for training)

### 1. Set Up NeMo Gym

```bash
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation

# Pre-build resource server venvs
for dir in resources_servers/reasoning_gym resources_servers/example_single_tool_call responses_api_models/vllm_model responses_api_agents/simple_agent; do
  uv venv --seed --allow-existing --python 3.12 $dir/.venv
  CFLAGS="" uv pip install --python $dir/.venv/bin/python pycosat --no-build-isolation 2>/dev/null
  uv pip install --python $dir/.venv/bin/python -e . "ray[default]==2.52.1"
done

# Install extra deps for reasoning_gym
uv pip install --python resources_servers/reasoning_gym/.venv/bin/python \
  reasoning-gym matplotlib pillow cycler contourpy kiwisolver
```

### 2. Multi-Turn with Async GRPO (Recommended — Fastest Path)

This is the fully validated, highest-performance path. NeMo Gym's agent server handles
multi-turn tool execution while axolotl's async GRPO prefetches data in background threads.

**Step 1: Create the NeMo Gym agent config**

Create `~/Gym/configs/axolotl_tool_calling.yaml`:

```yaml
# Resource server (tools + verify)
example_single_tool_call:
  resources_servers:
    example_single_tool_call:
      entrypoint: app.py
      domain: agent
      verified: false

# Model server proxy (forwards to your vLLM)
policy_model:
  responses_api_models:
    vllm_model:
      entrypoint: app.py
      base_url: http://localhost:8000/v1
      api_key: dummy_key
      model: Qwen/Qwen3-0.6B  # Must match your training model
      return_token_id_information: true
      uses_reasoning_parser: false

# Agent server (orchestrates multi-turn via /run)
example_single_tool_call_simple_agent:
  responses_api_agents:
    simple_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: example_single_tool_call
      model_server:
        type: responses_api_models
        name: policy_model
      datasets:
        - name: weather
          type: example
          jsonl_fpath: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
```

**Step 2: Start three services**

```bash
# Terminal 1: vLLM OpenAI server on GPU 0
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen3-0.6B --max-model-len 2048 --gpu-memory-utilization 0.85

# Terminal 2: NeMo Gym (resource server + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[configs/axolotl_tool_calling.yaml]" "+skip_venv_if_present=true"

# Terminal 3: Training on GPU 1
cd experiments && CUDA_VISIBLE_DEVICES=1 CUDA_HOME=$HOME/env-claude-cu130/cuda_shim \
  axolotl train nemo_gym_async_agent.yaml
```

**Step 3: Training config** (`nemo_gym_async_agent.yaml`):

```yaml
base_model: Qwen/Qwen3-0.6B
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
sequence_len: 2048

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 5
  # Async GRPO — 3x faster than standard
  use_data_producer: true
  async_prefetch: true
  num_generations: 4
  max_completion_length: 512
  temperature: 0.8
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_env

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    server_name: example_single_tool_call

datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.85
  max_model_len: 2048
  tensor_parallel_size: 1

learning_rate: 5e-6
micro_batch_size: 1
gradient_accumulation_steps: 4
max_steps: 30
gradient_checkpointing: true
bf16: true
output_dir: ./outputs/nemo_gym_async

use_wandb: true
wandb_project: nemo-gym-rl
```

### 3. Single-Turn Training (Simplest — No Agent Server Needed)

For environments that only need single-turn verification (math, coding challenges), you don't need
an agent server. The plugin's reward function calls `/verify` directly.

```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct
rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: true
  vllm_mode: colocate
  vllm_enable_sleep_mode: false
  num_generations: 8
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.3
  max_model_len: 512
  tensor_parallel_size: 1

learning_rate: 1e-5
micro_batch_size: 4
gradient_accumulation_steps: 2
max_steps: 50
output_dir: ./outputs/nemo_gym_arithmetic
```

This path only needs `ng_run` with resource servers (no agent config):

```bash
cd ~/Gym && ng_run "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" "+skip_venv_if_present=true"
```

## How It Works

### Single-Turn

```text
axolotl train → GRPO Trainer generates completions
  → NeMo Gym plugin reward_fn calls POST /verify on resource server
  → reward flows back to GRPO for advantage computation
```
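
The plugin ships `reward_nemo_gym_verify` for this path, so no user code is needed. For intuition only, here is a minimal, hypothetical sketch of the round-trip it performs; the exact `/verify` request schema belongs to NeMo Gym, and the payload shape below is an illustrative assumption, not the plugin's actual implementation:

```python
import requests

# Hypothetical sketch: score one completion against a NeMo Gym resource
# server. The real plugin builds the request from the row's original
# JSONL data ("verify_extra"); the payload shape here is an assumption.
def verify_reward(server_url: str, row: dict, completion: str, timeout: int = 30) -> float:
    resp = requests.post(
        f"{server_url}/verify",
        json={**row, "response": completion},
        timeout=timeout,  # mirrors nemo_gym_verify_timeout
    )
    resp.raise_for_status()
    # Deterministic scalar reward computed by the environment
    return float(resp.json().get("reward", 0.0))
```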

### Multi-Turn (Agent /run)

```text
┌─────────────┐      ┌──────────────┐      ┌──────────────────┐
│   axolotl   │      │   NeMo Gym   │────▶│   vLLM OpenAI    │
│    train    │────▶│  Agent /run  │◀────│  Server (GPU 0)  │
│   (GPU 1)   │      │              │      │ /v1/completions  │
└─────────────┘      └──────┬───────┘      └──────────────────┘
                            │
                            ▼
                     ┌──────────────┐
                     │   Resource   │
                     │    Server    │
                     │   (tools +   │
                     │    verify)   │
                     └──────────────┘
```

The agent server orchestrates the entire multi-turn loop:
1. Calls our vLLM server for model generation
2. Parses tool calls from model output
3. Executes tools against resource servers
4. Feeds tool results back to the model
5. Repeats until done, then calls /verify for reward
6. Returns token IDs + logprobs + reward to our rollout_func
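
A trimmed sketch of the `/run` call behind these steps, adapted from the plugin's `multi_turn.py` (see `_post_run` and `_call_agents` later in this PR); the error-handling shape matches the plugin, while the standalone session setup is simplified here:

```python
import aiohttp

# One rollout = one POST to the agent's /run endpoint. The response
# carries per-turn token IDs, logprobs, an env_mask, and the final reward.
async def run_rollout(agent_url: str, item: dict, timeout: float = 300) -> dict:
    # sock_read bounds each stuck read without capping the whole session
    client_timeout = aiohttp.ClientTimeout(total=None, sock_read=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(f"{agent_url}/run", json=item) as resp:
            if resp.status == 200:
                return await resp.json()
            return {"response": {"output": []}, "reward": 0.0,
                    "error": f"HTTP {resp.status}"}
```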

### Data Producer Architecture (Multi-Turn)

When `nemo_gym_multi_turn: true`, the plugin automatically forces `use_data_producer: true`,
which selects the `AxolotlAsyncGRPOTrainer`. The plugin then swaps the trainer's data
producer with `NemoGymDataProducer`, which:

1. Gets a prompt batch from the dataset iterator
2. Expands it by `num_generations` (one agent call per rollout)
3. Calls NeMo Gym agents via async HTTP (`asyncio.gather`)
4. Parses responses into padded tensors (`RolloutDataset`)
5. Returns with `_pending_policy_logps=True` for deferred scoring

The main thread then runs `_compute_deferred_scores()`, which:
- Computes **policy logprobs** on the training model (GPU forward pass)
- Computes the **IS correction** using the agent's sampling logprobs vs. the training model's logprobs
- Computes advantages with group-level normalization
- All downstream features work: replay buffer, re-roll, streaming, zero-adv skip

With `async_prefetch: true`, the data producer runs in a background thread — giving the ~3x
speedup as generation and training overlap. With `async_prefetch: false`, it runs
synchronously on the main thread (still using the data producer protocol).
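
Conceptually, the IS correction reduces to a per-token ratio between the training policy and the sampling policy that produced the rollout. A sketch for intuition only (not the trainer's exact code; the clamp floor is an illustrative value, though the changelog does note a "clamp min IS ratio" fix):

```python
import torch

# Conceptual sketch of token-level importance-sampling correction:
# ratio = pi_policy(token) / pi_sampling(token), computed in log space
# and clamped from below for stability. Not the trainer's exact code.
def is_ratio(policy_logps: torch.Tensor, sampling_logps: torch.Tensor,
             min_ratio: float = 0.1) -> torch.Tensor:
    ratio = torch.exp(policy_logps - sampling_logps)
    return ratio.clamp(min=min_ratio)  # floor value is an assumption
```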

### Weight Sync (LoRA Mode)

With `vllm_lora_sync: true`, the plugin (or async trainer) replaces NCCL-based weight
sync with filesystem + HTTP:

1. `accelerator.get_state_dict()` gathers LoRA weights from all ranks
2. Rank 0 saves the adapter to `/tmp/lora_sync_*/vN/`
3. Rank 0 POSTs to `/set_lora_adapter/` on the vLLM server
4. vLLM loads the adapter natively via Punica kernels
5. Only ~40MB transferred (vs multiple GBs for full model weights)
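
Step 3 mirrors the request in the `AsyncGRPOTrainer` diff above, condensed into a sketch:

```python
import requests

# Condensed from the trainer's LoRA sync path shown earlier in this PR:
# try the standard vLLM endpoint first, then fall back to the custom one.
def push_adapter(base_url: str, lora_name: str, adapter_path: str,
                 timeout: int = 300) -> None:
    resp = requests.post(
        f"{base_url}/v1/load_lora_adapter",
        json={"lora_name": lora_name, "lora_path": adapter_path, "load_inplace": True},
        timeout=timeout,
    )
    if resp.status_code != 200:
        # Fallback: custom /set_lora_adapter/ endpoint
        requests.post(
            f"{base_url}/set_lora_adapter/",
            json={"lora_name": "active_lora", "lora_path": adapter_path},
            timeout=30,
        )
```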

### Multi-Environment Support

Datasets support per-row environment routing via `agent_ref`:

```jsonl
{"agent_ref": {"name": "reasoning_gym"}, "responses_create_params": {...}}
{"agent_ref": {"name": "instruction_following"}, "responses_create_params": {...}}
```

Or use the simpler per-dataset routing:

```yaml
nemo_gym_datasets:
  - path: reasoning_data.jsonl
    server_name: reasoning_gym
  - path: tool_data.jsonl
    server_name: example_single_tool_call
```

## Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | `null` | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to the NeMo Gym repo |
| `nemo_gym_auto_clone` | bool | `true` | Auto-clone the NeMo Gym repo if missing |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_config_paths` | list[str] | — | Server config YAMLs (relative to gym_dir) |
| `nemo_gym_datasets` | list[dict] | required | Dataset configs with `path` and optional `server_name` |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_server_timeout` | int | `360` | Server startup timeout (seconds) |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent /run |

### Dataset JSONL Format

Each line must have `responses_create_params` with `input` messages:

```json
{
  "responses_create_params": {
    "input": [{"role": "user", "content": "What's the weather in SF?"}],
    "tools": [{"name": "get_weather", "type": "function", "strict": true, "parameters": {...}}]
  }
}
```

For multi-turn agent routing, include `agent_ref`:

```json
{"agent_ref": {"name": "my_agent"}, "responses_create_params": {...}}
```

Note: Tool definitions MUST include `"strict": true` and `"additionalProperties": false` for NeMo Gym agent compatibility.
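
For example, a hypothetical `get_weather` tool definition that satisfies both requirements (the parameters schema here is illustrative, not taken from the weather dataset):

```json
{
  "name": "get_weather",
  "type": "function",
  "strict": true,
  "parameters": {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
    "additionalProperties": false
  }
}
```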

### Reward Functions

The plugin provides two built-in reward functions — no user code needed:

```yaml
trl:
  reward_funcs:
    # Multi-turn (nemo_gym_multi_turn: true):
    # Passthrough — agent /run already computed the reward
    - axolotl.integrations.nemo_gym.rewards.reward_env

    # Single-turn (nemo_gym_multi_turn: false):
    # Calls /verify endpoints on NeMo Gym resource servers
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
```

Both are also importable from Python:

```python
from axolotl.integrations.nemo_gym import reward_env, reward_nemo_gym_verify
```

## Known Issues / Troubleshooting

### NeMo Gym Server Setup
- **pycosat build failure**: `CFLAGS="" uv pip install pycosat --no-build-isolation`
- **Ray version mismatch**: Pin `ray[default]==2.52.1` in all server venvs
- **Pre-build venvs**: `ng_run` creates per-server venvs via Ray. Pre-build them and use `+skip_venv_if_present=true`
- **Tool `strict` field required**: The agent server validates that tool definitions set `strict: true`

### vLLM / Weight Sync
- **Start vLLM with LoRA + tool calling + runtime loading**:
  ```bash
  VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 \
  CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-4B-Instruct-2507 \
    --max-model-len 4096 \
    --gpu-memory-utilization 0.7 \
    --enable-lora --max-lora-rank 64 \
    --enable-auto-tool-choice --tool-call-parser hermes
  ```
- **`VLLM_ALLOW_RUNTIME_LORA_UPDATING=1`**: Required for `vllm_lora_sync: true`. Without it, vLLM won't expose the `/v1/load_lora_adapter` endpoint and weight sync will fail silently. The plugin warns if this endpoint is missing.
- **`--enable-lora`**: Enables LoRA adapter support in vLLM
- **`--enable-auto-tool-choice --tool-call-parser hermes`**: Required for Qwen3 tool calling
- **`max_model_len` must be > `max_completion_length`**: Leave room for prompt tokens (~200). If they are equal, the NeMo Gym model proxy gets a 400 error and returns empty completions.
- **`CUDA_HOME` required**: The DeepSpeed import needs it for the nvcc shim
- **NCCL weight sync broken with vLLM 0.17**: Use `vllm_lora_sync: true` (filesystem + HTTP via `/v1/load_lora_adapter`)

### Multi-Turn
- **Agent server required**: Multi-turn delegates to NeMo Gym's agent server `/run` endpoint. Without an agent, the plugin falls back to single-turn `/verify`
- **Model server proxy**: NeMo Gym needs a `responses_api_models` server that proxies to your vLLM. See the agent config example above

### FSDP2
- Validated on 2 GPUs with single-turn + LoRA
- Async field filtering: the builder automatically filters async-only config fields when using the standard GRPO trainer

## Comparison with Other Integrations

| Feature | Axolotl + NeMo Gym | Unsloth + NeMo Gym | NeMo RL (native) |
|---------|-------------------|-------------------|-------------------|
| Server management | Automatic | Manual (notebook) | Built-in |
| Multi-environment | Per-row routing | Manual code | YAML config |
| Multi-turn / tool use | Agent /run delegation | No | Agent /run (Ray) |
| Async GRPO (3x speedup) | Yes | No | Yes |
| LoRA sync | Filesystem + HTTP | N/A | NCCL |
| Multi-GPU (FSDP2) | Yes | No | Yes (Ray) |
| Config-driven | Yes | No (code) | Yes |
25  src/axolotl/integrations/nemo_gym/__init__.py  Normal file
@@ -0,0 +1,25 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
Plugin for NVIDIA NeMo Gym integration with Axolotl.

NeMo Gym provides RL training environments for LLMs with verification-based
reward signals. This plugin manages the NeMo Gym server lifecycle, loads
datasets in the NeMo Gym JSONL format, and creates reward functions that
call the NeMo Gym /verify endpoints.
"""

from .args import NemoGymArgs
from .plugin import NemoGymPlugin
from .rewards import reward_env, reward_nemo_gym_verify

__all__ = [
    "NemoGymArgs",
    "NemoGymPlugin",
    "reward_env",
    "reward_nemo_gym_verify",
]
146  src/axolotl/integrations/nemo_gym/args.py  Normal file
@@ -0,0 +1,146 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
Input arguments for the NeMo Gym integration plugin.
"""

from pydantic import BaseModel, Field, model_validator


class NemoGymArgs(BaseModel):
    """Configuration args for the NeMo Gym integration."""

    nemo_gym_enabled: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Enable NeMo Gym integration for environment-based RL rewards."
        },
    )
    nemo_gym_dir: str | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Path to the NeMo Gym repository clone. "
                "If not set and nemo_gym_auto_clone is True, clones to ~/Gym."
            )
        },
    )
    nemo_gym_auto_clone: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Automatically clone the NeMo Gym repository if not present. "
                "Defaults to True when nemo_gym_enabled is set."
            )
        },
    )
    nemo_gym_config_paths: list[str] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "List of NeMo Gym resource server config YAML paths, relative to nemo_gym_dir. "
                "Example: ['resources_servers/reasoning_gym/configs/resources_only.yaml']"
            )
        },
    )
    nemo_gym_head_port: int | None = Field(
        default=11000,
        json_schema_extra={
            "description": "Port for the NeMo Gym head server. Defaults to 11000."
        },
    )
    nemo_gym_server_timeout: int | None = Field(
        default=360,
        json_schema_extra={
            "description": "Timeout in seconds waiting for NeMo Gym servers to start. Defaults to 360."
        },
    )
    nemo_gym_verify_timeout: int | None = Field(
        default=30,
        json_schema_extra={
            "description": "Timeout in seconds for individual /verify requests. Defaults to 30."
        },
    )
    nemo_gym_run_timeout: int | None = Field(
        default=300,
        json_schema_extra={
            "description": (
                "Timeout in seconds for each agent /run request (one multi-turn rollout). "
                "Prevents stuck generations (e.g. model looping on <think> tags) from "
                "blocking training indefinitely. Defaults to 300 (5 minutes)."
            )
        },
    )
    nemo_gym_datasets: list[dict] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "List of NeMo Gym dataset configs. Each entry has 'path' (JSONL file path "
                "relative to nemo_gym_dir) and optionally 'server_name' (default resource server). "
                "If the JSONL rows have agent_ref.name, that takes precedence per row, "
                "enabling multi-environment training from a single dataset file. "
                "Optional 'max_samples' to limit per dataset."
            )
        },
    )
    nemo_gym_auto_start: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": (
                "Automatically start NeMo Gym resource servers. Defaults to True. "
                "Set to False if servers are already running externally."
            )
        },
    )
    nemo_gym_model_name: str | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Model name to report in verify requests. "
                "Defaults to the base_model from the main config."
            )
        },
    )
    nemo_gym_multi_turn: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Enable multi-turn rollouts via NeMo Gym. When True, uses TRL's "
                "rollout_func to run multi-step interactions with tool execution. "
                "Requires use_vllm=True in TRL config. The model generates responses, "
                "tool calls are executed against resource servers, and results are "
                "fed back for the next turn. Final reward comes from /verify."
            )
        },
    )
    nemo_gym_max_turns: int | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Maximum number of turns per multi-turn rollout. Defaults to 10. "
                "Each turn consists of a model generation + optional tool execution."
            )
        },
    )

    @model_validator(mode="before")
    @classmethod
    def check_nemo_gym_config(cls, data):
        if data.get("nemo_gym_enabled"):
            if not data.get("nemo_gym_config_paths") and data.get(
                "nemo_gym_auto_start", True
            ):
                raise ValueError(
                    "nemo_gym_config_paths is required when nemo_gym_enabled=True "
                    "and nemo_gym_auto_start is not False."
                )
            if not data.get("nemo_gym_datasets"):
                raise ValueError(
                    "nemo_gym_datasets is required when nemo_gym_enabled=True. "
                    "Provide at least one dataset with 'path' and 'server_name'."
                )
        return data
226  src/axolotl/integrations/nemo_gym/data_producer.py  Normal file
@@ -0,0 +1,226 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
NeMo Gym Data Producer for async GRPO training.

Replaces GRPODataProducer to generate rollouts via NeMo Gym agent /run endpoints
instead of vLLM. The agent handles generation, tool execution, and reward computation.
Returns RolloutDataset in the same format as the standard producer, so all downstream
components (deferred scoring, IS correction, streaming, replay, re-roll) work unchanged.
"""

from __future__ import annotations

import asyncio
from typing import Any

import torch
from trl.trainer.utils import pad

from axolotl.core.trainers.grpo.async_trainer import GRPODataProducer, RolloutDataset
from axolotl.utils.logging import get_logger

from .multi_turn import _call_agents, _parse_agent_response

LOG = get_logger(__name__)


class NemoGymDataProducer(GRPODataProducer):
    """Produces GRPO rollouts by calling NeMo Gym agent /run endpoints.

    Drop-in replacement for GRPODataProducer. Instead of calling vLLM for generation,
    sends prompts to NeMo Gym agents which handle generation + tool execution + reward.
    Returns the same RolloutDataset format so deferred scoring, IS correction,
    replay buffer, and re-roll all work unchanged.
    """

    def __init__(
        self,
        *args,
        agent_servers: dict[str, str],
        dataset_lookup: dict,
        request_timeout: float = 10800,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._agent_servers = agent_servers
        self._dataset_lookup = dataset_lookup
        self._request_timeout = request_timeout

    def produce(
        self,
        model: Any,
        global_step: int,
        *,
        skip_policy_logps: bool = False,
        processing_class: Any = None,
        accelerator: Any = None,
        args: Any = None,
        _rank0_only: bool = False,
        **kwargs,
    ) -> RolloutDataset | None:
        """Generate rollouts via NeMo Gym agents.

        Calls agent /run endpoints, parses responses into padded tensors,
        and returns a RolloutDataset for deferred scoring on the main thread.
        """
        trainer = self._trainer
        is_main = trainer.accelerator.is_main_process
        device = trainer.accelerator.device

        if _rank0_only and not is_main:
            return None

        # Get prompt batch from iterator
        try:
            inputs = next(self._prompt_iter)
        except StopIteration:
            self._prompt_iter = iter(self._prompt_dl)
            inputs = next(self._prompt_iter)

        # Extract dataset items for agent calls
        dataset_items = []
        for inp in inputs:
            prompt_text = ""
            prompt = inp.get("prompt", [])
            if isinstance(prompt, list) and prompt:
                prompt_text = (
                    prompt[-1].get("content", "")
                    if isinstance(prompt[-1], dict)
                    else str(prompt[-1])
                )
            elif isinstance(prompt, str):
                prompt_text = prompt

            # Find the full dataset item, preserving agent_ref for routing
            full_item = self._dataset_lookup.get(prompt_text, {})
            item = full_item.get("verify_extra", {})
            if not item:
                item = {
                    "responses_create_params": {
                        "input": [{"role": "user", "content": prompt_text}]
                    }
                }
            # Preserve agent_ref from the dataset row for _call_agents routing
            if "agent_ref" in full_item and "agent_ref" not in item:
                item["agent_ref"] = full_item["agent_ref"]
            dataset_items.append(item)

        # Expand by num_generations (agent produces one rollout per call)
        expanded_items = []
        for item in dataset_items:
            for _ in range(self._num_generations):
                expanded_items.append(item)

        # Call NeMo Gym agents
        loop = asyncio.new_event_loop()
        try:
            responses = loop.run_until_complete(
                _call_agents(
                    dataset_items=expanded_items,
                    agent_servers=self._agent_servers,
                    timeout=self._request_timeout,
                    max_completion_length=trainer.max_completion_length,
                    temperature=trainer.temperature,
                    top_p=getattr(trainer, "top_p", None) or 0.999,
                )
            )
        finally:
            loop.close()

        # Parse responses
        eos_token_id = trainer.processing_class.eos_token_id
        prompt_ids_list = []
        completion_ids_list = []
        env_mask_list = []
        logprobs_list = []
        rewards_list = []

        for resp in responses:
            parsed = _parse_agent_response(resp, eos_token_id)
            prompt_ids_list.append(parsed["prompt_ids"])
            completion_ids_list.append(parsed["completion_ids"])
            env_mask_list.append(parsed["env_mask"])
            logprobs_list.append(parsed["logprobs"])
            rewards_list.append(parsed["reward"])

        # Pad to tensors
        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
        prompt_ids = pad(
            prompt_ids, padding_value=trainer.pad_token_id, padding_side="left"
        )
        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")

        completion_ids = [
            torch.tensor(ids, device=device) for ids in completion_ids_list
        ]
        completion_mask = [
            torch.ones_like(ids, dtype=torch.long) for ids in completion_ids
        ]
        completion_ids = pad(
            completion_ids, padding_value=trainer.pad_token_id, padding_side="right"
        )
        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")

        # Sampling logprobs from agent (used for IS correction)
        sampling_logps = [
            torch.tensor(lp, dtype=torch.float32, device=device) for lp in logprobs_list
        ]
        sampling_per_token_logps = pad(
            sampling_logps, padding_value=0.0, padding_side="right"
        )

        # env_mask as tool_mask (1=model tokens, 0=tool tokens)
        tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
        tool_mask = pad(tool_mask, padding_value=1, padding_side="right")

        # Inject rewards into inputs so _compute_deferred_scores can use them.
        # The deferred scoring path calls _calculate_rewards which reads reward_funcs.
        # Our passthrough reward_fn reads "env_reward" from kwargs.
        for i, inp in enumerate(inputs):
            # Each input gets rewards for its num_generations rollouts
            start = i * self._num_generations
            end = start + self._num_generations
            inp["env_reward"] = rewards_list[start:end]

        # Expand inputs to match expanded rollouts (num_generations copies)
        expanded_inputs = []
        for inp in inputs:
            for g in range(self._num_generations):
                expanded_inp = dict(inp)
                expanded_inp["env_reward"] = inp["env_reward"][g]
                expanded_inputs.append(expanded_inp)

        # Decode completions for reward functions
        completions = trainer.processing_class.batch_decode(
            completion_ids, skip_special_tokens=True
        )

        # Build total token count
        num_items_in_batch = completion_mask.sum()

        # Build output dict (same shape as _generate_only)
        output = {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "num_items_in_batch": num_items_in_batch,
            "advantages": torch.zeros(completion_ids.size(0), device=device),
            "sampling_per_token_logps": sampling_per_token_logps,
            "tool_mask": tool_mask,
            # Deferred scoring markers
            "_pending_policy_logps": True,
            "_deferred_inputs": expanded_inputs,
            "_deferred_prompts": [inp.get("prompt", "") for inp in expanded_inputs],
            "_deferred_completions": completions,
            "_deferred_completion_ids_list": completion_ids_list,
            "_rank0_only": _rank0_only,
        }

        return RolloutDataset(output)
135  src/axolotl/integrations/nemo_gym/dataset.py  Normal file
@@ -0,0 +1,135 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
Dataset loading for NeMo Gym JSONL files.

Converts NeMo Gym JSONL format into HuggingFace Datasets compatible
with TRL's GRPOTrainer. Supports multi-environment routing via:
1. Per-dataset server_name (all rows in a file go to one server)
2. Per-row agent_ref.name (each row specifies its own server)
"""

from __future__ import annotations

import json
import os
import random

from datasets import Dataset

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def load_nemo_gym_datasets(
    gym_dir: str,
    dataset_configs: list[dict],
) -> Dataset:
    """Load and merge NeMo Gym JSONL datasets with multi-environment support.

    Each dataset config should have:
    - path: JSONL file path (absolute, or relative to gym_dir)
    - server_name: Default NeMo Gym server for this dataset.
      Can be overridden per-row if the JSONL has an "agent_ref" field.
    - max_samples (optional): Max number of samples to use from this dataset

    Per-row routing: If a JSONL row has an "agent_ref": {"name": "..."} field,
    that takes precedence over the dataset-level server_name. This allows mixing
    environments within a single dataset file (matching TRL's pattern).

    The output dataset has columns:
    - prompt: list[dict] chat format
    - resources_server_ref: dict with {"name": server_name}
    - verify_extra: dict with original JSONL data for verify requests

    Args:
        gym_dir: Path to the NeMo Gym directory.
        dataset_configs: List of dataset configuration dicts.

    Returns:
        A HuggingFace Dataset ready for GRPOTrainer.
    """
    all_examples = []

    for ds_cfg in dataset_configs:
        path = ds_cfg["path"]
        default_server = ds_cfg.get("server_name", "")
        max_samples = ds_cfg.get("max_samples")

        # Resolve path
        if not os.path.isabs(path):
            path = os.path.join(gym_dir, path)
        path = os.path.expanduser(path)

        if not os.path.exists(path):
            raise FileNotFoundError(
                f"NeMo Gym dataset not found at {path}. "
                "Ensure the dataset file exists or run the appropriate "
                "NeMo Gym dataset creation script."
            )

        LOG.info(
            f"Loading NeMo Gym dataset from {path} (default server: {default_server})"
        )

        with open(path, encoding="utf-8") as f:
            lines = f.readlines()

        if max_samples and len(lines) > max_samples:
            lines = random.sample(lines, max_samples)  # nosec B311

        for line in lines:
            data = json.loads(line)

            # Extract user prompt from the input messages
            inputs = data.get("responses_create_params", {}).get("input", [])
            task_prompt = ""
            for inp in inputs:
                if isinstance(inp, dict) and inp.get("role") in ("user",):
                    task_prompt = inp.get("content", "")
                    break
            if not task_prompt and inputs:
                # Fallback: use the last input's content
                task_prompt = (
                    inputs[-1].get("content", "")
                    if isinstance(inputs[-1], dict)
                    else ""
                )

            # Per-row agent routing: agent_ref.name can override dataset-level server_name.
            # NeMo Gym datasets may use agent names (e.g., "reasoning_gym_simple_agent")
            # which differ from resource server names (e.g., "reasoning_gym").
            # The dataset-level server_name is always the fallback.
            row_agent_ref = data.get("agent_ref", {})
            server_name = default_server
            if row_agent_ref and row_agent_ref.get("name"):
                # Use per-row name, but only if it looks like a resource server name.
                # Agent names typically have "_simple_agent" or "_agent" suffix.
                row_name = row_agent_ref["name"]
                if row_agent_ref.get("type") != "responses_api_agents":
                    # Not an agent — could be a direct resource server reference
                    server_name = row_name

            all_examples.append(
                {
                    "prompt": [{"role": "user", "content": task_prompt}],
                    "resources_server_ref": {"name": server_name},
                    "verify_extra": data,
                }
            )

    random.shuffle(all_examples)

    # Log environment distribution
    env_counts: dict[str, int] = {}
    for ex in all_examples:
        name = ex["resources_server_ref"]["name"]
        env_counts[name] = env_counts.get(name, 0) + 1
    LOG.info(f"Loaded {len(all_examples)} NeMo Gym examples: {env_counts}")

    return Dataset.from_list(all_examples)
64  examples/nemo_gym_multi_env.yaml  Normal file
@@ -0,0 +1,64 @@
# Axolotl + NeMo Gym: Multi-Environment RL Training Example
#
# Trains on multiple NeMo Gym environments simultaneously:
#   - Mini Sudoku (reasoning_gym)
#   - Instruction Following
#
# Prerequisites:
#   - uv (https://github.com/astral-sh/uv) installed
#   - Download instruction_following.jsonl from nvidia/Nemotron-RL-instruction_following
#     and place it at ~/Gym/resources_servers/instruction_following/data/instruction_following.jsonl
#
# Usage:
#   axolotl train examples/nemo_gym_multi_env.yaml

base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM

sequence_len: 4096
load_in_4bit: false

# RL configuration
rl: grpo
chat_template: tokenizer_default

# GRPO / TRL settings
trl:
  use_vllm: false
  num_generations: 8
  max_completion_length: 2048

# NeMo Gym plugin
plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_config_paths:
  - resources_servers/reasoning_gym/configs/resources_only.yaml
  - resources_servers/instruction_following/configs/instruction_following.yaml
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
    server_name: reasoning_gym
    max_samples: 1000
  - path: resources_servers/instruction_following/data/instruction_following.jsonl
    server_name: instruction_following
    max_samples: 1000

# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit

num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64

logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_multi_env
79  examples/nemo_gym_multi_turn.yaml  Normal file
@@ -0,0 +1,79 @@
# Axolotl + NeMo Gym: Multi-Turn Tool-Use RL Training Example
#
# This config trains a model with multi-turn reinforcement learning
# using NeMo Gym's agentic environments. The model generates responses,
# executes tool calls against NeMo Gym resource servers, and receives
# environment feedback across multiple turns.
#
# Multi-turn mode uses TRL's rollout_func with env_mask to ensure only
# model-generated tokens contribute to the training loss.
#
# Prerequisites:
#   - uv (https://github.com/astral-sh/uv) installed
#   - vLLM must be enabled (required for generate_rollout_completions)
#
# Usage:
#   # Terminal 1: Start vLLM server
#   trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct
#
#   # Terminal 2: Train
#   axolotl train examples/nemo_gym_multi_turn.yaml

base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM

sequence_len: 4096
load_in_4bit: false

# RL configuration
rl: grpo
chat_template: tokenizer_default

# GRPO / TRL settings — vLLM required for multi-turn
trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  num_generations: 4
  max_completion_length: 2048

# NeMo Gym plugin with multi-turn enabled
plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_head_port: 11000
nemo_gym_server_timeout: 360
nemo_gym_verify_timeout: 30

# Multi-turn settings
nemo_gym_multi_turn: true
nemo_gym_max_turns: 10

# Resource server configs — use environments that support tool calls
nemo_gym_config_paths:
  - resources_servers/reasoning_gym/configs/resources_only.yaml
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
    server_name: reasoning_gym
    max_samples: 2000

# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit

num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64

logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_multi_turn
62  examples/nemo_gym_sudoku.yaml  Normal file
@@ -0,0 +1,62 @@
# Axolotl + NeMo Gym: Sudoku RL Training Example
#
# This config trains a model to solve 4x4 Mini Sudoku puzzles using GRPO
# with NeMo Gym's reasoning_gym environment as the reward signal.
#
# Prerequisites:
#   - uv (https://github.com/astral-sh/uv) installed
#   - Git installed
#   - The NeMo Gym repo will be auto-cloned to ~/Gym on first run
#
# Usage:
#   axolotl train examples/nemo_gym_sudoku.yaml

base_model: Qwen/Qwen2.5-1.5B-Instruct
model_type: AutoModelForCausalLM

sequence_len: 4096
load_in_4bit: false

# RL configuration
rl: grpo
chat_template: tokenizer_default

# GRPO / TRL settings
trl:
  use_vllm: false
  num_generations: 8
  max_completion_length: 2048

# NeMo Gym plugin
plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_clone: true
nemo_gym_auto_start: true
nemo_gym_head_port: 11000
nemo_gym_server_timeout: 360
nemo_gym_verify_timeout: 30
nemo_gym_config_paths:
  - resources_servers/reasoning_gym/configs/resources_only.yaml
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl
    server_name: reasoning_gym
    max_samples: 2000

# Training hyperparameters
learning_rate: 1.0e-5
weight_decay: 0.001
lr_scheduler: linear
warmup_ratio: 0.0
optimizer: adamw_8bit

num_epochs: 1
max_steps: 100
micro_batch_size: 1
gradient_accumulation_steps: 64

logging_steps: 1
save_steps: 100
output_dir: ./outputs/nemo_gym_sudoku
329  src/axolotl/integrations/nemo_gym/multi_turn.py  Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright 2026 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# This software may be used and distributed according to
|
||||
# the terms of the Axolotl Community License Agreement (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
"""
|
||||
Multi-turn rollout function for NeMo Gym environments.
|
||||
|
||||
Delegates multi-turn orchestration to NeMo Gym's agent servers via the /run
|
||||
endpoint. The agent handles generation (by calling our vLLM server), tool
|
||||
execution, session management, and reward computation.
|
||||
|
||||
This follows the same pattern as TRL's reference implementation at
|
||||
examples/scripts/nemo_gym/train_multi_environment.py.
|
||||
|
||||
Architecture:
|
||||
rollout_func(prompts, trainer)
|
||||
-> expand prompts by num_generations
|
||||
-> async POST /run to agent servers (one per sample)
|
||||
-> parse response: prompt_ids, completion_ids, logprobs, env_mask, reward
|
||||
-> return to TRL for GRPO training
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def create_nemo_gym_rollout_func(
|
||||
agent_servers: dict[str, str],
|
||||
dataset_lookup: dict[int, dict],
|
||||
request_timeout: float = 10800,
|
||||
):
|
||||
"""Create a TRL-compatible rollout_func that delegates to NeMo Gym agents.
|
||||
|
||||
Args:
|
||||
agent_servers: Mapping of agent_name → agent URL (e.g., {"simple_agent": "http://host:port"}).
|
||||
dataset_lookup: Mapping of dataset index → full JSONL row dict.
|
||||
request_timeout: HTTP timeout for /run requests.
|
||||
|
||||
Returns:
|
||||
A rollout_func with signature (prompts: list[str], trainer) -> dict.
|
||||
"""
|
||||
|
||||
def rollout_func(prompts: list[str], trainer) -> dict[str, Any]:
|
||||
is_training = trainer.model.training
|
||||
num_generations = (
|
||||
trainer.num_generations
|
||||
if is_training
|
||||
else getattr(trainer, "num_generations_eval", 1)
|
||||
)
|
||||
temperature = trainer.temperature
|
||||
top_p = getattr(trainer, "top_p", None) or 0.999
|
||||
max_completion_length = trainer.max_completion_length
|
||||
eos_token_id = trainer.processing_class.eos_token_id
|
||||
|
||||
# Expand prompts: each prompt index repeated num_generations times
|
||||
expanded_items = []
|
||||
expanded_prompt_indices = []
|
||||
for prompt_str in prompts:
|
||||
# Prompts from TRL are chat-templated strings. Find the dataset item
|
||||
# by matching against dataset_lookup keys (raw user message text).
|
||||
full_item = None
|
||||
for key, val in dataset_lookup.items():
|
||||
if isinstance(key, str) and prompt_str == key:
|
||||
full_item = val
|
||||
break
|
||||
|
||||
if full_item is None:
|
||||
full_item = {
|
||||
"responses_create_params": {
|
||||
"input": [{"role": "user", "content": prompt_str}]
|
||||
}
|
||||
}
|
||||
|
||||
for _ in range(num_generations):
|
||||
# Preserve agent_ref for routing in _call_agents
|
||||
dispatched: dict = full_item.get("verify_extra", full_item) # type: ignore[assignment]
|
||||
if isinstance(dispatched, dict) and "agent_ref" not in dispatched:
|
||||
agent_ref = full_item.get("agent_ref")
|
||||
if agent_ref:
|
||||
dispatched = {**dispatched, "agent_ref": agent_ref}
|
||||
expanded_items.append(dispatched)
|
||||
expanded_prompt_indices.append(prompt_str)
|
||||
|
||||
# Call NeMo Gym agents
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
responses = loop.run_until_complete(
|
||||
_call_agents(
|
||||
dataset_items=expanded_items,
|
||||
agent_servers=agent_servers,
|
||||
timeout=request_timeout,
|
||||
max_completion_length=max_completion_length,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Parse responses into rollout format
|
||||
all_prompt_ids = []
|
||||
all_completion_ids = []
|
||||
all_env_masks = []
|
||||
all_logprobs = []
|
||||
all_rewards = []
|
||||
all_num_turns = []
|
||||
|
||||
for _i, response in enumerate(responses):
|
||||
result = _parse_agent_response(response, eos_token_id)
|
||||
all_prompt_ids.append(result["prompt_ids"])
|
||||
all_completion_ids.append(result["completion_ids"])
|
||||
all_env_masks.append(result["env_mask"])
|
||||
all_logprobs.append(result["logprobs"])
|
||||
all_rewards.append(result["reward"])
|
||||
all_num_turns.append(result["num_turns"])
|
||||
|
||||
# TRL expects prompt_ids to be unique (one per original prompt, not per generation)
|
||||
unique_prompt_ids = all_prompt_ids[::num_generations]
|
||||
|
||||
# Wrap logprobs for TRL: list[list[list[float]]]
|
||||
def _normalize(lp):
|
||||
while isinstance(lp, (list, tuple)) and len(lp) > 0:
|
||||
lp = lp[0]
|
||||
return float(lp) if lp is not None else 0.0
|
||||
|
||||
wrapped_logprobs = [[[_normalize(lp)] for lp in seq] for seq in all_logprobs]
|
||||
|
||||
return {
|
||||
"prompt_ids": unique_prompt_ids,
|
||||
"completion_ids": all_completion_ids,
|
||||
"env_mask": all_env_masks,
|
||||
"logprobs": wrapped_logprobs,
|
||||
"logprob_token_ids": None, # nosec B105
|
||||
"env_reward": all_rewards,
|
||||
"num_turns": all_num_turns,
|
||||
}
|
||||
|
||||
return rollout_func
|
||||
|
||||
|
||||


async def _call_agents(
    dataset_items: list[dict],
    agent_servers: dict[str, str],
    timeout: float,
    max_completion_length: int = 4096,
    temperature: float = 1.0,
    top_p: float = 0.999,
) -> list[dict]:
    """Async batch POST to NeMo Gym agent /run endpoints."""
    import aiohttp

    results = []
    connector = aiohttp.TCPConnector(limit_per_host=64, limit=256)
    # Use sock_read for per-request timeout (not total session timeout).
    # This ensures a single stuck generation doesn't block all other requests.
    client_timeout = aiohttp.ClientTimeout(total=None, sock_read=timeout)

    async with aiohttp.ClientSession(
        connector=connector, timeout=client_timeout, cookie_jar=aiohttp.DummyCookieJar()
    ) as session:
        tasks = []
        for item in dataset_items:
            agent_ref = item.get("agent_ref", {})
            agent_name = agent_ref.get("name", "")
            agent_url = agent_servers.get(agent_name, "")

            if not agent_url:
                # Fallback: try first available agent
                if agent_servers:
                    agent_url = next(iter(agent_servers.values()))
                else:
                    results.append(
                        {
                            "response": {"output": []},
                            "reward": 0.0,
                            "error": "No agent server",
                        }
                    )
                    continue

            # Build request body
            request_body = dict(item)
            params = request_body.setdefault("responses_create_params", {})
            params.setdefault("max_output_tokens", max_completion_length)
            params["temperature"] = temperature
            params["top_p"] = top_p

            tasks.append(_post_run(session, agent_url, request_body))

        if tasks:
            responses = await asyncio.gather(*tasks, return_exceptions=True)
            for resp in responses:
                if isinstance(resp, BaseException):
                    LOG.warning(f"Agent /run failed: {resp}")
                    results.append(
                        {"response": {"output": []}, "reward": 0.0, "error": str(resp)}
                    )
                else:
                    results.append(resp)

    return results


async def _post_run(session, agent_url: str, body: dict) -> dict:
    """POST to agent /run endpoint."""
    async with session.post(f"{agent_url}/run", json=body) as resp:
        if resp.status == 200:
            return await resp.json()
        text = await resp.text()
        return {
            "response": {"output": []},
            "reward": 0.0,
            "error": f"HTTP {resp.status}: {text[:200]}",
        }


def _parse_agent_response(response: dict, eos_token_id: int) -> dict:
    """Parse NeMo Gym agent /run response into rollout format.

    The agent returns:
        response.output[]: list of turns, each with prompt_token_ids,
            generation_token_ids, generation_log_probs
        reward: float
    """
    # Defaults for failed/empty responses
    defaults = {
        "prompt_ids": [eos_token_id],
        "completion_ids": [eos_token_id],
        "env_mask": [0],
        "logprobs": [0.0],
        "reward": 0.0,
        "num_turns": 0,
    }

    if not isinstance(response, dict) or "error" in response:
        return defaults

    output_items = response.get("response", {}).get("output", [])
    reward = float(response.get("reward", 0.0))

    if not output_items:
        defaults["reward"] = reward
        return defaults

    # Check at least one valid output
    has_valid = False
    for item in output_items:
        if item.get("type") == "function_call":
            has_valid = True
            break
        if item.get("type") == "message":
            for c in item.get("content", []):
                if (
                    isinstance(c, dict)
                    and c.get("type") == "output_text"
                    and c.get("text", "").strip()
                ):
                    has_valid = True
                    break
        if has_valid:
            break

    if not has_valid:
        defaults["reward"] = reward
        return defaults

    # Extract multi-turn token sequences
    first_prompt_ids = None
    completion_ids = []
    env_mask = []
    logprobs = []
    seen_token_ids = []
    num_turns = 0

    for item in output_items:
        prompt_token_ids = item.get("prompt_token_ids", [])
        generation_token_ids = item.get("generation_token_ids", [])
        generation_log_probs = item.get("generation_log_probs", [])

        if not generation_token_ids:
            continue

        num_turns += 1

        # First turn: capture prompt
        if first_prompt_ids is None:
            first_prompt_ids = list(prompt_token_ids)
            seen_token_ids = list(prompt_token_ids)
        else:
            # Subsequent turns: extract tool result tokens (between turns)
            if len(prompt_token_ids) > len(seen_token_ids):
                tool_result_tokens = prompt_token_ids[len(seen_token_ids) :]
                # Tool result tokens are NOT trained on (env_mask = 0)
                completion_ids.extend(tool_result_tokens)
                env_mask.extend([0] * len(tool_result_tokens))
                logprobs.extend([0.0] * len(tool_result_tokens))

        # Add generation tokens (trained on, env_mask = 1)
        completion_ids.extend(generation_token_ids)
        env_mask.extend([1] * len(generation_token_ids))

        # Pad logprobs if shorter than generation tokens
        gen_logprobs = list(generation_log_probs) if generation_log_probs else []
        if len(gen_logprobs) < len(generation_token_ids):
            gen_logprobs.extend([0.0] * (len(generation_token_ids) - len(gen_logprobs)))
        logprobs.extend(gen_logprobs[: len(generation_token_ids)])

        # Update seen tokens
        seen_token_ids = list(prompt_token_ids) + list(generation_token_ids)

    if first_prompt_ids is None:
        return defaults

    return {
        "prompt_ids": first_prompt_ids,
        "completion_ids": completion_ids,
        "env_mask": env_mask,
        "logprobs": logprobs,
        "reward": reward,
        "num_turns": num_turns,
    }
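
# Worked example (illustrative): for two turns with
#   turn 1: prompt_token_ids=[1, 2],          generation_token_ids=[3, 4]
#   turn 2: prompt_token_ids=[1, 2, 3, 4, 5], generation_token_ids=[6]
# the parser returns prompt_ids=[1, 2], completion_ids=[3, 4, 5, 6] and
# env_mask=[1, 1, 0, 1]: token 5 is a tool-result token and is masked out.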

src/axolotl/integrations/nemo_gym/plugin.py
@@ -0,0 +1,503 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
NeMo Gym Plugin for Axolotl.

Integrates NVIDIA NeMo Gym environments as reward sources for GRPO training.
Handles server lifecycle, dataset loading, and reward function wiring.

Supports two modes:
- Single-turn (default): reward_fn calls /verify after each generation
- Multi-turn (nemo_gym_multi_turn: true): rollout_func orchestrates
  multi-step interactions with tool execution via resource servers
"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Union

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta

LOG = get_logger(__name__)


class NemoGymPlugin(BasePlugin):
    """Plugin for NVIDIA NeMo Gym integration with Axolotl.

    When enabled, this plugin:
    1. Clones and sets up the NeMo Gym repo (if needed)
    2. Starts NeMo Gym resource servers
    3. Loads datasets from NeMo Gym JSONL files
    4. For single-turn: creates a reward function calling /verify
    5. For multi-turn: creates a rollout_func with tool execution and env_mask
    """

    def __init__(self):
        super().__init__()
        self._gym_dir = None
        self._global_config = None
        self._verify_endpoints = None
        self._server_base_urls = None
        self._reward_fn = None
        self._dataset_lookup = None
        self._agent_servers = {}

    def get_input_args(self):
        return "axolotl.integrations.nemo_gym.NemoGymArgs"

    def pre_model_load(self, cfg):
        """Apply monkeypatches before trainer creation."""
        if not cfg.nemo_gym_enabled:
            return

        # Always skip NCCL communicator init in NeMo Gym mode.
        # NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
        # colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
        trl_cfg = getattr(cfg, "trl", None)
        if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
            self._patch_skip_nccl_init()

    def _patch_skip_nccl_init(self):
        """Monkeypatch VLLMClient.init_communicator to a no-op.

        NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
        serve script). The NCCL communicator is not needed and fails with both
        the vLLM V1 engine and standard OpenAI server mode.
        """
        try:
            from trl.generation.vllm_client import VLLMClient

            VLLMClient._original_init_communicator = VLLMClient.init_communicator
            VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
                "Skipping NCCL init_communicator (LoRA sync mode)"
            )
            LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
        except Exception as exc:
            LOG.warning(f"Failed to patch VLLMClient: {exc}")

    def register(self, cfg):
        if not cfg.get("nemo_gym_enabled"):
            return

        LOG.info("NeMo Gym integration enabled")
        gym_dir = cfg.get("nemo_gym_dir") or os.path.expanduser("~/Gym")
        auto_clone = cfg.get("nemo_gym_auto_clone", True)
        auto_start = cfg.get("nemo_gym_auto_start", True)
        head_port = cfg.get("nemo_gym_head_port", 11000)
        server_timeout = cfg.get("nemo_gym_server_timeout", 360)

        from .server import (
            ensure_gym_repo,
            ensure_gym_venv,
            get_agent_servers,
            get_server_base_url,
            get_server_configs,
            get_verify_endpoint,
            start_servers,
            wait_for_resource_servers,
        )

        self._gym_dir = ensure_gym_repo(gym_dir, auto_clone=auto_clone)

        if auto_start:
            config_paths = cfg.get("nemo_gym_config_paths", [])
            ensure_gym_venv(self._gym_dir)
            start_servers(
                self._gym_dir,
                config_paths,
                head_port=head_port,
                timeout=server_timeout,
            )

        self._global_config = get_server_configs(head_port=head_port)
        wait_for_resource_servers(self._global_config, timeout=server_timeout)

        # Build endpoint maps for resource servers (/verify)
        self._verify_endpoints = {}
        self._server_base_urls = {}
        for server_name in self._global_config:
            try:
                self._verify_endpoints[server_name] = get_verify_endpoint(
                    self._global_config, server_name
                )
                self._server_base_urls[server_name] = get_server_base_url(
                    self._global_config, server_name
                )
            except (ValueError, KeyError, TypeError):
                pass

        # Discover agent servers (/run) for multi-turn
        self._agent_servers = get_agent_servers(self._global_config)

        # Pre-build the dataset lookup for multi-turn (this needs to happen at
        # register time, not in load_datasets, because load_datasets may not be
        # called if the axolotl config has its own datasets field)
        if cfg.get("nemo_gym_multi_turn") and cfg.get("nemo_gym_datasets"):
            from .dataset import load_nemo_gym_datasets

            dataset = load_nemo_gym_datasets(self._gym_dir, cfg["nemo_gym_datasets"])
            self._dataset_lookup = {}
            for i in range(len(dataset)):
                row = dataset[i]
                # Use last message content as key (matches data_producer lookup)
                prompt_text = row["prompt"][-1]["content"]
                self._dataset_lookup[prompt_text] = row
            LOG.info(f"Built dataset lookup with {len(self._dataset_lookup)} entries")

        multi_turn = cfg.get("nemo_gym_multi_turn", False)
        LOG.info(
            f"NeMo Gym ready with servers: {list(self._verify_endpoints.keys())} "
            f"(multi_turn={'enabled' if multi_turn else 'disabled'})"
        )

    def load_datasets(self, cfg, preprocess=False) -> Union["TrainDatasetMeta", None]:
        if not cfg.nemo_gym_enabled:
            return None

        from axolotl.common.datasets import TrainDatasetMeta

        from .dataset import load_nemo_gym_datasets

        dataset_configs = cfg.nemo_gym_datasets
        dataset = load_nemo_gym_datasets(self._gym_dir, dataset_configs)

        # Build prompt → row lookup for multi-turn rollout_func
        # (rollout_func only receives prompt text, needs to look up row data)
        self._dataset_lookup = {}
        for i in range(len(dataset)):
            row = dataset[i]
            # Use last message content as key (matches data_producer lookup)
            prompt_text = row["prompt"][-1]["content"]
            self._dataset_lookup[prompt_text] = row

        return TrainDatasetMeta(
            train_dataset=dataset,
            eval_dataset=None,
            total_num_steps=0,  # computed later by the builder
        )

    def get_training_args(self, cfg):
        """Pass through vLLM settings and force async trainer for multi-turn."""
        args = {}
        # Pass vLLM settings from vllm config block to TRL training args
        if cfg.vllm:
            vllm_cfg = cfg.vllm
            max_len = getattr(vllm_cfg, "max_model_len", None)
            gpu_util = getattr(vllm_cfg, "gpu_memory_utilization", None)
            tp_size = getattr(vllm_cfg, "tensor_parallel_size", None)
            if max_len:
                args["vllm_max_model_length"] = max_len
            if gpu_util:
                args["vllm_gpu_memory_utilization"] = gpu_util
            if tp_size:
                args["vllm_tensor_parallel_size"] = tp_size

        # Force async trainer for multi-turn: NemoGymDataProducer needs the
        # data producer protocol. Setting use_data_producer=True selects
        # AxolotlAsyncGRPOTrainer which supports _create_data_producer().
        # With async_prefetch=False this runs synchronously — no threading.
        if cfg.nemo_gym_multi_turn and self._agent_servers:
            args["use_data_producer"] = True
            LOG.info(
                "NeMo Gym multi-turn: forcing use_data_producer=True for data producer protocol"
            )

        # Dataloader workers fork subprocesses that can't handle the async
        # HTTP connections to NeMo Gym agents. Force num_workers=0.
        if getattr(cfg, "dataloader_num_workers", None) not in (None, 0):
            LOG.warning(
                f"NeMo Gym: overriding dataloader_num_workers={cfg.dataloader_num_workers} → 0 "
                "(forked workers can't use NeMo Gym agent connections)"
            )
            cfg.dataloader_num_workers = 0

        if args:
            LOG.info(f"NeMo Gym plugin injecting training args: {args}")
        return args if args else None

    def post_trainer_create(self, cfg, trainer):
        """Wire NeMo Gym into the trainer (reward_fn or rollout_func)."""
        if not cfg.nemo_gym_enabled:
            return

        model_name = cfg.nemo_gym_model_name or cfg.base_model or "axolotl-model"
        verify_timeout = cfg.nemo_gym_verify_timeout or 30
        multi_turn = cfg.nemo_gym_multi_turn or False

        # Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
        # - Install LoRA sync (when vllm_lora_sync=True)
        # - Or no-op sync_weights (when using standard vLLM server)
        trl_cfg = getattr(cfg, "trl", None)
        if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
            vllm_gen = trainer.vllm_generation
            if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
                self._setup_lora_sync(trainer)
                # Verify the vLLM server supports runtime LoRA loading
                self._check_lora_endpoint(vllm_gen)
            else:
                # No NCCL, no LoRA sync — skip all weight sync paths
                vllm_gen.sync_weights = lambda: LOG.debug(
                    "Weight sync skipped (NeMo Gym mode)"
                )
                type(vllm_gen).sync_weights = lambda self: LOG.debug(
                    "Weight sync skipped (NeMo Gym mode)"
                )
                # Also patch the async trainer's internal sync method
                if hasattr(trainer, "_maybe_sync_vllm_weights"):
                    trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
                        "Async weight sync skipped (NeMo Gym mode)"
                    )
                LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")

        if multi_turn:
            self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
        else:
            self._wire_single_turn(trainer, model_name, verify_timeout)

    def _wire_single_turn(self, trainer, model_name, verify_timeout):
        """Inject single-turn reward function into the trainer."""
        from .rewards import create_nemo_gym_reward_fn

        self._reward_fn = create_nemo_gym_reward_fn(
            global_config=self._global_config,
            verify_endpoints=self._verify_endpoints,
            model_name=model_name,
            verify_timeout=verify_timeout,
        )

        if hasattr(trainer, "reward_funcs"):
            trainer.reward_funcs.append(self._reward_fn)
            trainer.reward_func_names.append("nemo_gym")
            trainer.reward_processing_classes.append(None)
            LOG.info(
                f"Added NeMo Gym reward function (single-turn). "
                f"Total reward functions: {len(trainer.reward_funcs)}"
            )
        else:
            LOG.warning(
                "Trainer does not have reward_funcs attribute. "
                "NeMo Gym reward function not injected. "
                "Ensure you are using a GRPO trainer."
            )

    def _wire_multi_turn(self, cfg, trainer, model_name, verify_timeout):
        """Replace the data producer with NemoGymDataProducer.

        The plugin forces use_data_producer=True (in get_training_args) which
        selects AxolotlAsyncGRPOTrainer. Here we swap its data_producer with
        our NemoGymDataProducer that calls agent /run instead of vLLM generate.
        """
        if not self._agent_servers:
            LOG.warning(
                "No NeMo Gym agent servers discovered. Multi-turn requires agent servers "
                "started via ng_run with an agent config. Falling back to single-turn."
            )
            self._wire_single_turn(trainer, model_name, verify_timeout)
            return

        if not hasattr(trainer, "data_producer") or trainer.data_producer is None:
            LOG.warning(
                "Trainer has no data_producer. NeMo Gym multi-turn requires "
                "use_data_producer=true (should be auto-set by plugin)."
            )
            return

        from axolotl.core.trainers.grpo.async_trainer import AsyncDataProducer

        from .data_producer import NemoGymDataProducer

        # Get the current producer's config and params; unwrap AsyncDataProducer
        # to reach the inner producer's config
        current = trainer.data_producer
        if isinstance(current, AsyncDataProducer):
            inner = current._inner
        else:
            inner = current

        nemo_producer = NemoGymDataProducer(
            config=inner.config,
            prompt_dataset=inner._dataset,
            num_generations=inner._num_generations,
            generation_batch_size=inner._generation_batch_size,
            train_batch_size=inner._train_batch_size,
            steps_per_generation=inner._steps_per_generation,
            shuffle_dataset=inner._shuffle_dataset,
            seed=inner._seed,
            agent_servers=self._agent_servers,
            dataset_lookup=self._dataset_lookup or {},
            request_timeout=float(cfg.nemo_gym_run_timeout or 300),
        )
        nemo_producer.set_trainer(trainer)

        # Re-wrap in AsyncDataProducer if async prefetch is enabled
        if getattr(trainer.args, "async_prefetch", False):
            nemo_producer = AsyncDataProducer(
                nemo_producer,
                background_produce_kwargs={"skip_policy_logps": True},
            )

        trainer.data_producer = nemo_producer
        LOG.info(
            f"NeMo Gym data producer installed "
            f"(agent servers: {list(self._agent_servers.keys())}, "
            f"async={'yes' if getattr(trainer.args, 'async_prefetch', False) else 'no'})"
        )

        # Passthrough reward function — agent /run already computed rewards
        from .rewards import reward_env

        if hasattr(trainer, "reward_funcs"):
            trainer.reward_funcs.append(reward_env)
            trainer.reward_func_names.append("nemo_gym")
            trainer.reward_processing_classes.append(None)

    @staticmethod
    def _check_lora_endpoint(vllm_gen):
        """Verify the vLLM server supports runtime LoRA loading."""
        import requests as http_requests

        if not hasattr(vllm_gen, "vllm_client") or vllm_gen.vllm_client is None:
            return  # Non-main rank in multi-GPU — client only exists on rank 0
        base_url = vllm_gen.vllm_client.base_url
        try:
            # Send a dummy load request — if the endpoint exists, we get a
            # proper error (400/404 about the adapter), not a route 404.
            resp = http_requests.post(
                f"{base_url}/v1/load_lora_adapter",
                json={"lora_name": "__probe__", "lora_path": "/nonexistent"},
                timeout=5,
            )
            if (
                resp.status_code == 404
                and "Not Found" in resp.text
                and "adapter" not in resp.text.lower()
            ):
                LOG.warning(
                    "vLLM server does not expose /v1/load_lora_adapter. "
                    "Set VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 when starting vLLM, e.g.:\n"
                    "  VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 python -m vllm.entrypoints.openai.api_server "
                    "--enable-lora --max-lora-rank 64 ..."
                )
        except Exception:
            pass  # Server might not be up yet, sync will warn later

    def _setup_lora_sync(self, trainer):
        """Replace sync_weights with LoRA adapter sync via filesystem + HTTP.

        If the async trainer is detected (has ``_sync_lora_adapter``), delegates
        to it — that method already handles multi-GPU (FSDP/DeepSpeed state_dict
        gather, broadcast sync dir, barrier).

        Otherwise installs a standalone closure for the non-async GRPO path that
        saves the adapter and POSTs to ``/v1/load_lora_adapter``.
        """
        vllm_gen = trainer.vllm_generation

        # Async trainer path: delegate to its _sync_lora_adapter (multi-GPU safe)
        if hasattr(trainer, "_sync_lora_adapter"):

            def lora_sync_weights():
                trainer._sync_lora_adapter()

            vllm_gen.sync_weights = lora_sync_weights
            type(vllm_gen).sync_weights = lambda self: lora_sync_weights()
            LOG.info(
                "Installed LoRA adapter sync "
                "(delegates to async trainer._sync_lora_adapter)"
            )
            return

        # Non-async standard GRPO path: standalone closure
        import os
        import shutil
        import tempfile

        import requests as http_requests

        base_model = getattr(trainer.args, "model_name_or_path", None) or "axolotl-lora"
        sync_state = {"version": 0, "sync_dir": tempfile.mkdtemp(prefix="lora_sync_")}

        def lora_sync_weights():
            """Save LoRA adapter and load it into vLLM."""
            accelerator = vllm_gen.accelerator
            model = vllm_gen.model

            if vllm_gen.mode != "server":
                return

            sync_state["version"] += 1
            version = sync_state["version"]
            adapter_path = os.path.join(sync_state["sync_dir"], f"v{version}")

            wrapped_model = getattr(trainer, "model_wrapped", model)
            state_dict = accelerator.get_state_dict(wrapped_model)

            if accelerator.is_main_process:
                unwrapped = accelerator.unwrap_model(model)
                unwrapped.save_pretrained(adapter_path, state_dict=state_dict)

                base_url = vllm_gen.vllm_client.base_url
                resp = http_requests.post(
                    f"{base_url}/v1/load_lora_adapter",
                    json={
                        "lora_name": base_model,
                        "lora_path": adapter_path,
                        "load_inplace": True,
                    },
                    timeout=30,
                )
                if resp.status_code != 200:
                    resp = http_requests.post(
                        f"{base_url}/set_lora_adapter/",
                        json={
                            "lora_name": "active_lora",
                            "lora_int_id": version,
                            "lora_path": adapter_path,
                        },
                        timeout=30,
                    )
                    if resp.status_code != 200:
                        LOG.warning(
                            f"Failed to set LoRA adapter: "
                            f"{resp.status_code} {resp.text}"
                        )
                        return

                try:
                    vllm_gen.vllm_client.reset_prefix_cache()
                except Exception as exc:
                    LOG.warning("Failed to reset prefix cache: %s", exc)

                if version > 1:
                    old = os.path.join(sync_state["sync_dir"], f"v{version - 1}")
                    if os.path.exists(old):
                        shutil.rmtree(old, ignore_errors=True)

                LOG.info(f"Synced LoRA adapter v{version} to vLLM ({adapter_path})")

            if accelerator.num_processes > 1:
                import torch.distributed as dist

                if dist.is_initialized():
                    dist.barrier()

        vllm_gen.sync_weights = lora_sync_weights
        type(vllm_gen).sync_weights = lambda self: lora_sync_weights()
        LOG.info("Installed LoRA adapter sync (standalone fallback)")

    def post_train_unload(self, cfg):
        """Clean up NeMo Gym servers if we started them."""
        if cfg.get("nemo_gym_enabled") and cfg.get("nemo_gym_auto_start", True):
            from .server import _cleanup_servers

            _cleanup_servers()

src/axolotl/integrations/nemo_gym/rewards.py
@@ -0,0 +1,274 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
NeMo Gym reward functions.

Provides ready-to-use reward functions for axolotl configs::

    trl:
      reward_funcs:
        # Multi-turn: passthrough reward from agent /run
        - axolotl.integrations.nemo_gym.rewards.reward_env
        # Single-turn: call /verify endpoints directly
        - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
"""

from __future__ import annotations

import threading
from typing import Any

import numpy as np
import requests

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


# ---------------------------------------------------------------------------
# Multi-turn passthrough reward
# ---------------------------------------------------------------------------


def reward_env(completions, prompts=None, **kwargs):
    """Passthrough: extract pre-computed reward from NeMo Gym agent /run response.

    The ``NemoGymDataProducer`` injects ``env_reward`` into each sample's
    kwargs after the agent returns from ``/run``. This function simply
    forwards that value so TRL can log it alongside other reward signals.

    Use this in your config when ``nemo_gym_multi_turn: true``::

        trl:
          reward_funcs:
            - axolotl.integrations.nemo_gym.rewards.reward_env
    """
    env_rewards = kwargs.get("env_reward")
    if env_rewards is not None:
        if isinstance(env_rewards, (list, tuple)):
            return [float(r) for r in env_rewards]
        return [float(env_rewards) for _ in completions]
    return [0.0 for _ in completions]
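
# e.g. reward_env([[{"role": "assistant", "content": "hi"}]], env_reward=[1.0])
# returns [1.0]; a scalar env_reward is broadcast to every completion.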

# ---------------------------------------------------------------------------
# Single-turn /verify reward
# ---------------------------------------------------------------------------

# Module-level cache for discovered verify URLs
_verify_urls: dict[str, str] = {}
_verify_urls_lock = threading.Lock()


def _get_verify_urls(head_port: int = 11000) -> dict[str, str]:
    """Discover verify endpoints from the NeMo Gym head server.

    Results are cached so that the HTTP round-trip only happens once per
    process. A lock guards against concurrent discovery from multiple
    threads (e.g. async_prefetch background thread + main training thread).
    """
    global _verify_urls
    if _verify_urls:
        return _verify_urls

    with _verify_urls_lock:
        # Double-check after acquiring lock
        if _verify_urls:
            return _verify_urls

        import yaml

        try:
            resp = requests.get(
                f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
            )
            config = yaml.safe_load(resp.text)
            if isinstance(config, str):
                config = yaml.safe_load(config)
            for _name, cfg in config.items():
                if not isinstance(cfg, dict):
                    continue
                for srv_name, srv_cfg in cfg.get("resources_servers", {}).items():
                    if (
                        isinstance(srv_cfg, dict)
                        and "host" in srv_cfg
                        and "port" in srv_cfg
                    ):
                        _verify_urls[srv_name] = (
                            f"http://{srv_cfg['host']}:{srv_cfg['port']}/verify"
                        )
        except Exception as exc:
            LOG.warning(f"Failed to discover NeMo Gym verify endpoints: {exc}")

    return _verify_urls
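
# e.g. _get_verify_urls() -> {"math_server": "http://127.0.0.1:11001/verify"}
# (names/ports illustrative); subsequent calls return the cached dict without
# another HTTP round-trip.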


def reward_nemo_gym_verify(completions, prompts=None, **kwargs):
    """Call NeMo Gym ``/verify`` endpoint for each completion (single-turn).

    Requires ``resources_server_ref`` and ``verify_extra`` kwargs, which the
    NeMo Gym dataset loader injects automatically.

    Use this in your config when ``nemo_gym_multi_turn: false``::

        trl:
          reward_funcs:
            - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
    """
    verify_urls = _get_verify_urls()
    refs = kwargs.get("resources_server_ref", [])
    extras = kwargs.get("verify_extra", [])
    scores = []

    for i, completion in enumerate(completions):
        text = completion[0]["content"] if completion else ""
        prompt = prompts[i][0]["content"] if prompts and i < len(prompts) else ""
        srv_name = (
            refs[i]["name"] if i < len(refs) and isinstance(refs[i], dict) else ""
        )
        url = verify_urls.get(srv_name, "")

        if not url:
            scores.append(0.0)
            continue

        extra = extras[i] if i < len(extras) else {}
        req = {k: v for k, v in extra.items() if v is not None}
        req["responses_create_params"] = {
            "input": [{"role": "user", "content": prompt}]
        }
        req["response"] = {
            "id": "resp",
            "created_at": 0,
            "model": "axolotl",
            "object": "response",
            "output": [
                {
                    "id": "msg",
                    "role": "assistant",
                    "type": "message",
                    "status": "completed",
                    "content": [
                        {"type": "output_text", "text": text, "annotations": []}
                    ],
                }
            ],
            "parallel_tool_calls": True,
            "tool_choice": "auto",
            "tools": [],
        }

        try:
            resp = requests.post(url, json=req, timeout=30)
            reward = resp.json().get("reward", 0.0) if resp.ok else 0.0
        except Exception as exc:
            LOG.warning(f"Verify request to {url} failed: {exc}")
            reward = 0.0

        scores.append(float(reward))

    return scores


# ---------------------------------------------------------------------------
# Factory used internally by the plugin
# ---------------------------------------------------------------------------


def create_nemo_gym_reward_fn(
    global_config: dict,
    verify_endpoints: dict[str, str],
    model_name: str = "axolotl-model",
    verify_timeout: int = 30,
):
    """Create a reward function bound to specific verify endpoints.

    Used internally by ``NemoGymPlugin._wire_single_turn()`` to inject a
    reward function that already knows the endpoint map (no discovery needed).
    """

    def reward_fn(
        completions: list[list[dict[str, str]]],
        prompts: list[list[dict[str, str]]] | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        resources_server_refs = kwargs.get("resources_server_ref", [])
        verify_extras = kwargs.get("verify_extra", [])

        scores = []
        for i, completion in enumerate(completions):
            completion_text = completion[0]["content"]
            task_prompt = prompts[i][0]["content"] if prompts else ""

            server_name = (
                resources_server_refs[i]["name"]
                if i < len(resources_server_refs)
                else None
            )

            if server_name is None or server_name not in verify_endpoints:
                LOG.warning(
                    f"No verify endpoint for server '{server_name}', returning 0 reward"
                )
                scores.append(0.0)
                continue

            verify_endpoint = verify_endpoints[server_name]

            verify_request = (
                {k: v for k, v in verify_extras[i].items() if v is not None}
                if i < len(verify_extras)
                else {}
            )
            verify_request["responses_create_params"] = {
                "input": [{"role": "user", "content": task_prompt}]
            }
            verify_request["response"] = {
                "id": "resp",
                "created_at": 0,
                "model": model_name,
                "object": "response",
                "output": [
                    {
                        "id": "msg",
                        "role": "assistant",
                        "type": "message",
                        "status": "completed",
                        "content": [
                            {
                                "type": "output_text",
                                "text": completion_text,
                                "annotations": [],
                            }
                        ],
                    }
                ],
                "parallel_tool_calls": True,
                "tool_choice": "auto",
                "tools": [],
            }

            try:
                resp = requests.post(
                    verify_endpoint, json=verify_request, timeout=verify_timeout
                )
                if resp.status_code == 200:
                    reward = resp.json().get("reward", 0.0)
                else:
                    LOG.warning(
                        f"Verify request returned status {resp.status_code}: {resp.text[:200]}"
                    )
                    reward = 0.0
            except requests.exceptions.RequestException as exc:
                LOG.warning(f"Verify request failed: {exc}")
                reward = 0.0

            scores.append(float(reward))

        return np.array(scores)

    return reward_fn

src/axolotl/integrations/nemo_gym/server.py
@@ -0,0 +1,242 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
NeMo Gym server lifecycle management.

Handles cloning the NeMo Gym repo, starting resource servers,
waiting for readiness, and cleanup on exit.
"""

from __future__ import annotations

import atexit
import os
import subprocess  # nosec B404
import time

import requests
import yaml

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

_ng_process = None
_ng_log_file = None


def ensure_gym_repo(gym_dir: str, auto_clone: bool = True) -> str:
    """Clone the NeMo Gym repo if it doesn't exist.

    Args:
        gym_dir: Path to the NeMo Gym directory.
        auto_clone: Whether to auto-clone if missing.

    Returns:
        Resolved path to the NeMo Gym directory.
    """
    gym_dir = os.path.expanduser(gym_dir)
    if os.path.exists(gym_dir):
        LOG.info(f"NeMo Gym directory exists at {gym_dir}")
        return gym_dir

    if not auto_clone:
        raise FileNotFoundError(
            f"NeMo Gym directory not found at {gym_dir} and auto_clone is disabled."
        )

    LOG.info(f"Cloning NeMo Gym to {gym_dir}...")
    subprocess.run(  # nosec
        ["git", "clone", "https://github.com/NVIDIA-NeMo/Gym.git", gym_dir],
        check=True,
    )
    return gym_dir


def ensure_gym_venv(gym_dir: str):
    """Set up the NeMo Gym Python venv if not present."""
    venv_python = os.path.join(gym_dir, ".venv", "bin", "python")
    if os.path.exists(venv_python):
        return

    LOG.info("Setting up NeMo Gym venv...")
    subprocess.run(["uv", "venv", "--python", "3.12"], cwd=gym_dir, check=True)  # nosec
    subprocess.run(  # nosec
        ["bash", "-c", "source .venv/bin/activate && uv sync"],
        cwd=gym_dir,
        check=True,
    )


def start_servers(
    gym_dir: str,
    config_paths: list[str],
    head_port: int = 11000,
    timeout: int = 360,
):
    """Start NeMo Gym resource servers via ng_run.

    Args:
        gym_dir: Path to the NeMo Gym directory.
        config_paths: List of config YAML paths relative to gym_dir.
        head_port: Port for the head server.
        timeout: Max seconds to wait for servers.
    """
    global _ng_process, _ng_log_file

    head_url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"

    # Check if already running
    try:
        requests.get(head_url, timeout=2)
        LOG.info(f"NeMo Gym servers already running on port {head_port}.")
        return
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        pass

    ng_run_bin = os.path.join(gym_dir, ".venv", "bin", "ng_run")
    config_arg = f"+config_paths=[{','.join(config_paths)}]"
    _ng_log_file = open(os.path.join(gym_dir, "ng_run.log"), "w")  # noqa: SIM115
    _ng_process = subprocess.Popen(  # nosec B603
        [ng_run_bin, config_arg, "+skip_venv_if_present=true"],
        cwd=gym_dir,
        stdout=_ng_log_file,
        stderr=subprocess.STDOUT,
    )

    atexit.register(_cleanup_servers)

    LOG.info("Waiting for NeMo Gym head server...")
    for _ in range(timeout // 3):
        try:
            requests.get(head_url, timeout=2)
            LOG.info("NeMo Gym head server is ready.")
            return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            if _ng_process.poll() is not None:
                raise RuntimeError(
                    "NeMo Gym server process exited unexpectedly. "
                    f"Check {gym_dir}/ng_run.log for details."
                ) from None
            time.sleep(3)

    raise RuntimeError(
        f"NeMo Gym servers did not start within {timeout}s. "
        f"Check {gym_dir}/ng_run.log for details."
    )


def get_server_configs(head_port: int = 11000) -> dict:
    """Fetch the global config from the NeMo Gym head server.

    Returns:
        Dict mapping server_name -> server config.
    """
    response = requests.get(
        f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
    )
    response.raise_for_status()
    result = yaml.safe_load(response.text)
    # NeMo Gym head server double-encodes: YAML string inside a YAML string
    if isinstance(result, str):
        result = yaml.safe_load(result)
    return result
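
# e.g. the head server may return the YAML itself as a quoted YAML string:
#   yaml.safe_load('"a: 1\n"') yields 'a: 1\n' (still a str), so a second
#   safe_load is needed to obtain {"a": 1}.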


def get_agent_servers(
    global_config: dict, head_host: str = "127.0.0.1"
) -> dict[str, str]:
    """Discover NeMo Gym agent servers from the global config.

    Agent servers handle multi-turn orchestration via the /run endpoint.
    Returns mapping of agent_name → URL (e.g., {"simple_agent": "http://host:port"}).
    """
    agents = {}
    for top_name, top_cfg in global_config.items():
        if not isinstance(top_cfg, dict):
            continue
        agent_dict = top_cfg.get("responses_api_agents", {})
        if not agent_dict:
            continue
        for _agent_name, agent_cfg in agent_dict.items():
            if not isinstance(agent_cfg, dict):
                continue
            host = agent_cfg.get("host", "127.0.0.1")
            port = agent_cfg.get("port")
            if not port:
                continue
            # Replace loopback with head_host for remote access
            host = _normalize_host(host, fallback=head_host)
            # Use the top-level config name (not the inner agent name)
            # because dataset agent_ref.name references the top-level name
            agents[top_name] = f"http://{host}:{port}"
    if agents:
        LOG.info(f"Discovered NeMo Gym agent servers: {agents}")
    return agents


def _normalize_host(host: str, fallback: str = "127.0.0.1") -> str:
    """Normalize bind-all and loopback addresses for reachability."""
    if host in ("0.0.0.0", "localhost"):  # nosec B104
        return fallback
    return host
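
# e.g. _normalize_host("0.0.0.0") -> "127.0.0.1" (a bind-all address is not a
# reachable client target), while _normalize_host("10.0.0.5") is returned as-is.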


def get_server_base_url(global_config: dict, server_name: str) -> str:
    """Get the base URL for a given resource server."""
    try:
        srv_cfg = global_config[server_name]["resources_servers"][server_name]
        host = _normalize_host(srv_cfg["host"])
        return f"http://{host}:{srv_cfg['port']}"
    except (KeyError, TypeError) as exc:
        raise ValueError(
            f"Could not find resource server config for '{server_name}' in NeMo Gym. "
            f"Available servers: {list(global_config.keys())}"
        ) from exc


def get_verify_endpoint(global_config: dict, server_name: str) -> str:
    """Get the /verify endpoint URL for a given resource server."""
    return f"{get_server_base_url(global_config, server_name)}/verify"


def wait_for_resource_servers(global_config: dict, timeout: int = 180):
    """Wait for all resource servers in the config to become reachable."""
    for srv_name in global_config:
        try:
            srv_cfg = global_config[srv_name]["resources_servers"][srv_name]
        except (KeyError, TypeError):
            continue  # Skip non-server config entries silently

        host, port = _normalize_host(srv_cfg["host"]), srv_cfg["port"]
        LOG.info(f"Waiting for resource server '{srv_name}' at {host}:{port}...")
        for _ in range(timeout // 2):
            try:
                requests.get(f"http://{host}:{port}/", timeout=2)
                LOG.info(f"Resource server '{srv_name}' is ready.")
                break
            # Also catch Timeout so a slow response keeps us waiting
            # instead of crashing the readiness loop
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
                time.sleep(2)
        else:
            raise RuntimeError(
                f"Resource server '{srv_name}' at {host}:{port} "
                f"did not start within {timeout}s."
            )


def _cleanup_servers():
    """Terminate NeMo Gym server process on exit."""
    global _ng_process, _ng_log_file
    if _ng_process is not None and _ng_process.poll() is None:
        LOG.info("Terminating NeMo Gym servers...")
        _ng_process.terminate()
        try:
            _ng_process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            _ng_process.kill()
    if _ng_log_file is not None:
        _ng_log_file.close()
@@ -446,6 +446,114 @@ def main(script_args: ScriptArguments):
            "logprob_token_ids": extract_logprobs(all_outputs)[1],
        }

    # --- OpenAI-compatible endpoints (for NeMo Gym agent integration) ---

    @app.get("/v1/models")
    async def list_models():
        """OpenAI-compatible models endpoint."""
        return {
            "object": "list",
            "data": [
                {"id": script_args.model, "object": "model", "owned_by": "axolotl"}
            ],
        }

    @app.post("/v1/chat/completions")
    async def openai_chat_completions(request_body: dict):
        """OpenAI-compatible chat completions endpoint.

        Translates OpenAI format to our internal /chat/ format so NeMo Gym's
        model server proxy can call us directly.
        """
        messages_list = request_body.get("messages", [])
        temperature = request_body.get("temperature", 1.0)
        max_tokens = request_body.get("max_tokens", 512)
        top_p = request_body.get("top_p", 1.0)
        n = request_body.get("n", 1)

        generation_kwargs = {
            "n": n,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "logprobs": 0,  # Always return logprobs (NeMo Gym needs them)
        }
        sampling_params = SamplingParams(
            **{k: v for k, v in generation_kwargs.items() if v is not None}
        )

        # Send to vLLM workers
        chunked = chunk_list([messages_list], script_args.data_parallel_size)
        for conn, chunk in zip(connections, chunked, strict=True):
            if not chunk:
                chunk = [[{"role": "user", "content": "<placeholder>"}]]
            kwargs = {
                "messages": chunk,
                "sampling_params": sampling_params,
                "use_tqdm": False,
                "lora_request": active_lora["request"],
            }
            conn.send({"type": "call", "method": "chat", "kwargs": kwargs})

        all_outputs = [conn.recv() for conn in connections]
        all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
        all_outputs = list(chain.from_iterable(all_outputs))

        if not all_outputs:
            return {"choices": [], "model": script_args.model}

        # Format as an OpenAI response
        import uuid

        choices = []
        for i, output in enumerate(all_outputs):
            for j, out in enumerate(output.outputs):
                text = out.text
                # Build logprobs in OpenAI format
                lp_list = None
                if out.logprobs:
                    lp_list = {
                        "content": [
                            {"token": "", "logprob": next(iter(lp.values())).logprob}  # nosec B105
                            for lp in out.logprobs
                        ]
                    }

                choice = {
                    "index": i * n + j,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop"
                    if out.finish_reason == "stop"
                    else "length",
                    "logprobs": lp_list,
                }
                # Include token ID information for NeMo Gym
                choice["prompt_token_ids"] = output.prompt_token_ids
                choice["generation_token_ids"] = list(out.token_ids)
                if out.logprobs:
                    choice["generation_log_probs"] = [
                        next(iter(lp.values())).logprob for lp in out.logprobs
                    ]
                choices.append(choice)

        prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
        completion_tokens = sum(
            len(out.token_ids) for o in all_outputs for out in o.outputs
        )

        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "model": script_args.model,
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
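
    # Illustrative client call (host/port are whatever this server was started
    # with; nothing here fixes them):
    #   requests.post(f"{base_url}/v1/chat/completions",
    #                 json={"messages": [{"role": "user", "content": "hi"}],
    #                       "max_tokens": 32})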

    # --- Weight sync endpoints (legacy fallback, same as TRL) ---

    @app.post("/init_communicator/")