diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 60c34933d..603b58466 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -18,6 +18,8 @@ feedback. Various methods include, but not limited to: - [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Group Relative Policy Optimization (GRPO)](#grpo) - [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo) +- [Energy-Based Fine-Tuning (EBFT)](#ebft) +- [NeMo Gym Integration](#nemo-gym-integration) ## RLHF using Axolotl @@ -1037,6 +1039,302 @@ simpo_gamma: 0.5 # default in CPOTrainer This method uses the same dataset format as [DPO](#dpo). +### EBFT + +EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments. + +Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) + +**Key advantages:** + +- No reward model or verifier required — works on any (prompt, completion) data +- Applicable to non-verifiable tasks (code, translation, creative writing) +- Operates on model rollouts (not teacher forcing), reducing distribution shift + +EBFT supports two modes: + +- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO). +- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed. + +#### Structured Mode + +```yaml +base_model: Qwen/Qwen3-4B + +rl: ebft + +ebft: + feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth + embed_method: last_token + use_whitening: false + alignment_coef: 1.0 # Cosine similarity reward weight + diversity_coef: 1.0 # Pairwise dot product penalty + ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off) + +trl: + num_generations: 4 + max_completion_length: 256 + temperature: 0.7 + use_vllm: true + vllm_server_host: 0.0.0.0 + vllm_server_port: 8000 + vllm_lora_sync: true # LoRA adapter sync (recommended) + vllm_sync_interval: 3 + use_data_producer: true + async_prefetch: true # Set false for sync mode + scale_rewards: true + loss_type: grpo + epsilon: 0.2 + +vllm: + gpu_memory_utilization: 0.5 + max_model_len: 2048 + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_opencode.transform + split: train[:500] + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_target_linear: true +``` + +```bash +# Terminal 1: Start vLLM +CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml + +# Terminal 2: Train +CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml +``` + +#### Strided Mode + +For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU. 
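+
+Conceptually, strided mode slides anchor points across each document: every `stride` tokens, the model conditions on the preceding `context_length` tokens, samples up to `generate_max_len` continuation tokens, and is rewarded for matching the features of the ground-truth continuation. A rough sketch of the blocking, using the defaults from the config below (hypothetical helper, not the trainer's actual code):
+
+```python
+# Illustrative only: how strided anchors carve a token stream into
+# (context, ground-truth target) blocks. Not an axolotl API.
+def strided_blocks(tokens, stride=8, context_length=8, generate_max_len=8):
+    blocks = []
+    for anchor in range(context_length, len(tokens) - generate_max_len + 1, stride):
+        context = tokens[anchor - context_length : anchor]   # conditioning window
+        target = tokens[anchor : anchor + generate_max_len]  # ground-truth continuation
+        blocks.append((context, target))
+    return blocks
+```
+
+The blocks are then generated block-parallel with flex_attention, which is why no vLLM server is involved.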
+ +```yaml +base_model: meta-llama/Llama-3.2-1B + +rl: ebft + +ebft: + mode: strided + stride: 8 + context_length: 8 + generate_max_len: 8 + n_samples_per_prompt: 4 + temperature: 0.6 + feature_layers: [0.25, 0.5, 0.75] + embed_method: last_token + use_whitening: true + alignment_coef: 1.0 + diversity_coef: 1.0 + rl_coef: 1.0 + ce_coef: 0.03 + advantage_estimator: rloo + +datasets: + - path: nvidia/OpenCodeInstruct + type: ebft_strided_structured.transform + split: train[:1%] + +flash_attention: false +flex_attention: true # Strided mode uses flex_attention +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true # Required for flex_attention +``` + +```bash +CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml +``` + +::: {.callout-tip} +See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes. +::: + +#### EBFT Configuration Reference + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) | +| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` | +| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions | +| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight | +| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight | +| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens | +| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) | +| `ebft.stride` | — | Tokens between anchor points (strided mode) | +| `ebft.context_length` | — | Context window per block (strided mode) | +| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) | +| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) | +| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) | + +### NeMo Gym Integration + +[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`). + +#### Single-Turn (Simplest) + +For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server. 
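+
+Under the hood, the bundled reward function sends each completion to the resource server's `/verify` endpoint and uses the returned score as the GRPO reward. A minimal sketch of that round trip — the payload fields shown here are illustrative assumptions; the real implementation lives in `axolotl.integrations.nemo_gym.rewards`:
+
+```python
+import requests
+
+# Sketch of the single-turn reward path. Field names in the JSON body are
+# assumptions; the plugin builds the real payload from each dataset row.
+def verify_reward(prompt: str, completion: str, server_url: str) -> float:
+    resp = requests.post(
+        f"{server_url}/verify",
+        json={"prompt": prompt, "response": completion},
+        timeout=30,  # mirrors the nemo_gym_verify_timeout default
+    )
+    resp.raise_for_status()
+    return float(resp.json().get("reward", 0.0))
+```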
+ +```yaml +base_model: Qwen/Qwen2.5-0.5B-Instruct + +rl: grpo +chat_template: tokenizer_default + +trl: + use_vllm: false # Colocate mode (single GPU) + num_generations: 4 + max_completion_length: 128 + temperature: 0.9 + reward_funcs: + - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify + +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_dir: ~/Gym +nemo_gym_auto_start: false +nemo_gym_head_port: 11000 +nemo_gym_datasets: + - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl + server_name: reasoning_gym + +datasets: + - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl + type: chat_template + field_messages: responses_create_params.input + message_field_content: content + message_field_role: role +``` + +```bash +# Terminal 1: Start NeMo Gym resource server +cd ~/Gym && .venv/bin/ng_run \ + "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \ + "+skip_venv_if_present=true" + +# Terminal 2: Train +CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml +``` + +::: {.callout-note} +`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined. +::: + +#### Multi-Turn with Async GRPO (Recommended) + +For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done. + +```yaml +base_model: Qwen/Qwen3-0.6B + +rl: grpo +chat_template: tokenizer_default + +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] + +trl: + use_vllm: true + vllm_mode: server + vllm_server_host: localhost + vllm_server_port: 8000 + vllm_lora_sync: true + vllm_sync_interval: 5 + use_data_producer: true + async_prefetch: true # 3x speedup + num_generations: 4 + max_completion_length: 512 + temperature: 0.8 + reward_funcs: + - axolotl.integrations.nemo_gym.rewards.reward_env + +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_auto_start: false +nemo_gym_head_port: 11000 +nemo_gym_multi_turn: true +nemo_gym_verify_timeout: 120 +nemo_gym_datasets: + - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl + server_name: example_single_tool_call + +datasets: + - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl + type: chat_template + field_messages: responses_create_params.input + message_field_content: content + message_field_role: role + +vllm: + gpu_memory_utilization: 0.85 + max_model_len: 2048 +``` + +Multi-turn requires three services running: + +```bash +# Terminal 1: vLLM with LoRA + tool calling +VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \ + python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen3-0.6B --max-model-len 2048 \ + --gpu-memory-utilization 0.85 \ + --enable-lora --max-lora-rank 64 \ + --enable-auto-tool-choice --tool-call-parser hermes + +# Terminal 2: NeMo Gym servers (resource + model proxy + agent) +cd ~/Gym && .venv/bin/ng_run \ + "+config_paths=[configs/axolotl_tool_calling.yaml]" \ + "+skip_venv_if_present=true" + +# Terminal 3: Training +CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml +``` + +::: {.callout-important} +Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an 
agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format. +::: + +#### NeMo Gym Prerequisites + +```bash +# Clone and set up NeMo Gym +git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym +cd ~/Gym +uv venv --python 3.12 && source .venv/bin/activate && uv sync + +# Fix pycosat build (GCC 13+) +CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation +``` + +#### NeMo Gym Configuration Reference + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration | +| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo | +| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers | +| `nemo_gym_head_port` | int | `11000` | Head server port | +| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` | +| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) | +| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` | + +#### Reward Functions + +| Function | Mode | Description | +|----------|------|-------------| +| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward | +| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` | + ### Using local dataset files ```yaml diff --git a/examples/ebft/README.md b/examples/ebft/README.md index 533e13652..24f2c582d 100644 --- a/examples/ebft/README.md +++ b/examples/ebft/README.md @@ -47,14 +47,11 @@ For **unstructured text** without prompt/completion splits (e.g., raw code, pros ### Structured Mode (QA data + vLLM) ```bash -# 1. Start vLLM server -python -m trl.scripts.vllm_serve \ - --model meta-llama/Llama-3.2-1B \ - --host 0.0.0.0 --port 8000 \ - --gpu-memory-utilization 0.3 +# 1. Start vLLM server (LoRA serve module auto-selected when vllm_lora_sync: true) +CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen3-4b-ebft-structured-async.yaml -# 2. Train -axolotl train examples/ebft/llama-1b-ebft-opencode.yaml +# 2. Train on a separate GPU +CUDA_VISIBLE_DEVICES=1 axolotl train examples/ebft/qwen3-4b-ebft-structured-async.yaml ``` ### Strided Mode (unstructured text) diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 89d4c9ff7..c5cdd3792 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -176,6 +176,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase): ) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() + if not async_grpo: + # Filter out async/fast-async-only fields not in standard GRPOConfig. + # These are defined in FastAsyncGRPOConfig and only used by + # AxolotlAsyncGRPOConfig. Standard GRPOConfig rejects them. 
+ import dataclasses + + from trl import GRPOConfig as _BaseGRPOConfig + + from axolotl.core.trainers.grpo.fast_async_trainer import ( + FastAsyncGRPOConfig, + ) + + async_only_fields = { + f.name for f in dataclasses.fields(FastAsyncGRPOConfig) + } - {f.name for f in dataclasses.fields(_BaseGRPOConfig)} + blocklist_args_kwargs.extend(list(async_only_fields)) if self.cfg.rl is RLType.GDPO: training_args_kwargs.setdefault( "multi_objective_aggregation", "normalize_then_sum" diff --git a/src/axolotl/core/trainers/ebft/__init__.py b/src/axolotl/core/trainers/ebft/__init__.py index 23b61fbe6..92abe9f26 100644 --- a/src/axolotl/core/trainers/ebft/__init__.py +++ b/src/axolotl/core/trainers/ebft/__init__.py @@ -34,7 +34,16 @@ class EBFTStrategy: return AxolotlStridedEBFTTrainer # Structured mode: async or sync - use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False) + # use_data_producer also triggers async trainer (needed for LoRA sync + # without async_prefetch, since sync trainer lacks LoRA sync support) + use_async = ( + cfg + and cfg.trl + and ( + getattr(cfg.trl, "async_prefetch", False) + or getattr(cfg.trl, "use_data_producer", False) + ) + ) if use_async: from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer @@ -50,7 +59,14 @@ class EBFTStrategy: return AxolotlStridedEBFTConfig # Structured mode: async or sync config - use_async = cfg and cfg.trl and getattr(cfg.trl, "async_prefetch", False) + use_async = ( + cfg + and cfg.trl + and ( + getattr(cfg.trl, "async_prefetch", False) + or getattr(cfg.trl, "use_data_producer", False) + ) + ) if use_async: return AxolotlAsyncEBFTConfig return AxolotlEBFTConfig diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index 2b8cda6d8..9b6ae2e28 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -1012,27 +1012,44 @@ class AsyncGRPOTrainer(GRPOTrainer): import requests vllm_client = self.vllm_generation.vllm_client - url = f"{vllm_client.base_url}/set_lora_adapter/" + base_url = vllm_client.base_url + base_model = getattr(self.args, "model_name_or_path", "axolotl-lora") sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300 + + # Try standard vLLM /v1/load_lora_adapter first, fall back to custom endpoint response = requests.post( - url, + f"{base_url}/v1/load_lora_adapter", json={ - "lora_name": "active_lora", - "lora_int_id": self._lora_sync_version, + "lora_name": base_model, "lora_path": adapter_path, + "load_inplace": True, }, timeout=sync_timeout, ) if response.status_code != 200: - logger.warning( - "Failed to set LoRA adapter: %s %s", - response.status_code, - response.text, + # Fallback: try custom /set_lora_adapter/ endpoint + response = requests.post( + f"{base_url}/set_lora_adapter/", + json={ + "lora_name": "active_lora", + "lora_int_id": self._lora_sync_version, + "lora_path": adapter_path, + }, + timeout=30, ) - return + if response.status_code != 200: + logger.warning( + "Failed to set LoRA adapter: %s %s", + response.status_code, + response.text, + ) + return # Reset prefix cache after adapter update - vllm_client.reset_prefix_cache() + try: + vllm_client.reset_prefix_cache() + except Exception as exc: + logger.warning("Failed to reset prefix cache: %s", exc) # Clean up old adapter versions (keep only current) if self._lora_sync_version > 1: @@ -2486,6 +2503,9 @@ class AsyncGRPOTrainer(GRPOTrainer): logits, completion_ids, self.temperature ) all_logps.append(logps) 
+ # Liger fused path doesn't compute entropy — append zeros + if compute_entropy: + all_entropies.append(torch.zeros_like(logps)) else: logits = logits[:, :-1, :] logits = logits[:, -logits_to_keep:, :] diff --git a/src/axolotl/integrations/nemo_gym/README.md b/src/axolotl/integrations/nemo_gym/README.md new file mode 100644 index 000000000..2a1f267c1 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/README.md @@ -0,0 +1,412 @@ +# NeMo Gym Integration for Axolotl + +Train LLMs with reinforcement learning using [NVIDIA NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) environments as reward sources. NeMo Gym provides 50+ verified RL environments spanning math, coding, tool-use, reasoning, and safety — each with deterministic reward signals. + +## Validated Training Paths + +| Path | Speed | Multi-turn | Architecture | +|------|-------|------------|--------------| +| **Async GRPO + Data Producer** | Fastest (3x) | Yes | `NemoGymDataProducer` replaces vLLM generation | +| Standard GRPO + Data Producer | Baseline | Yes | Same producer, no async prefetch | +| Standard GRPO + /verify | Simplest | No | Reward function calls /verify directly | +| FSDP2 + /verify (2 GPU) | Distributed | No | `fsdp_version: 2` | + +Multi-turn uses `nemo_gym_multi_turn: true` which auto-enables the async trainer's +data producer protocol. The plugin's `NemoGymDataProducer` calls NeMo Gym agent `/run` +endpoints and returns `RolloutDataset` with proper IS correction, env_mask, and rewards. + +All paths tested end-to-end with Qwen3-0.6B + LoRA, logged to wandb project `nemo-gym-rl`. + +## Quick Start + +### Prerequisites + +- [uv](https://github.com/astral-sh/uv) package manager (for NeMo Gym's venv) +- Two GPUs recommended (one for vLLM server, one for training) + +### 1. Set Up NeMo Gym + +```bash +git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym +cd ~/Gym +uv venv --python 3.12 && source .venv/bin/activate && uv sync + +# Fix pycosat build (GCC 13+) +CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation + +# Pre-build resource server venvs +for dir in resources_servers/reasoning_gym resources_servers/example_single_tool_call responses_api_models/vllm_model responses_api_agents/simple_agent; do + uv venv --seed --allow-existing --python 3.12 $dir/.venv + CFLAGS="" uv pip install --python $dir/.venv/bin/python pycosat --no-build-isolation 2>/dev/null + uv pip install --python $dir/.venv/bin/python -e . "ray[default]==2.52.1" +done + +# Install extra deps for reasoning_gym +uv pip install --python resources_servers/reasoning_gym/.venv/bin/python \ + reasoning-gym matplotlib pillow cycler contourpy kiwisolver +``` + +### 2. Multi-Turn with Async GRPO (Recommended — Fastest Path) + +This is the fully validated, highest-performance path. NeMo Gym's agent server handles +multi-turn tool execution while axolotl's async GRPO prefetches data in background threads. 
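+
+At each step the data producer issues one `/run` request per rollout; the agent replies with per-turn token IDs, logprobs, and a final reward (see `multi_turn.py` in this PR). A trimmed sketch of that exchange, assuming a locally reachable agent URL:
+
+```python
+import aiohttp
+
+# Minimal sketch of one agent /run call, mirroring _post_run/_call_agents
+# in multi_turn.py. agent_url and row are assumptions for illustration.
+async def run_rollout(agent_url: str, row: dict) -> dict:
+    body = dict(row)
+    params = body.setdefault("responses_create_params", {})
+    params.setdefault("max_output_tokens", 512)
+    params["temperature"] = 0.8
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f"{agent_url}/run", json=body) as resp:
+            return await resp.json()  # contains response.output[] plus reward
+```
+
+Wrap it in `asyncio.run(...)` to execute one rollout; the real producer fans out many such calls concurrently with `asyncio.gather` (see `_call_agents`).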
+ +**Step 1: Create the NeMo Gym agent config** + +Create `~/Gym/configs/axolotl_tool_calling.yaml`: +```yaml +# Resource server (tools + verify) +example_single_tool_call: + resources_servers: + example_single_tool_call: + entrypoint: app.py + domain: agent + verified: false + +# Model server proxy (forwards to your vLLM) +policy_model: + responses_api_models: + vllm_model: + entrypoint: app.py + base_url: http://localhost:8000/v1 + api_key: dummy_key + model: Qwen/Qwen3-0.6B # Must match your training model + return_token_id_information: true + uses_reasoning_parser: false + +# Agent server (orchestrates multi-turn via /run) +example_single_tool_call_simple_agent: + responses_api_agents: + simple_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: example_single_tool_call + model_server: + type: responses_api_models + name: policy_model + datasets: + - name: weather + type: example + jsonl_fpath: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl +``` + +**Step 2: Start three services** + +```bash +# Terminal 1: vLLM OpenAI server on GPU 0 +CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen3-0.6B --max-model-len 2048 --gpu-memory-utilization 0.85 + +# Terminal 2: NeMo Gym (resource server + model proxy + agent) +cd ~/Gym && .venv/bin/ng_run \ + "+config_paths=[configs/axolotl_tool_calling.yaml]" "+skip_venv_if_present=true" + +# Terminal 3: Training on GPU 1 +cd experiments && CUDA_VISIBLE_DEVICES=1 CUDA_HOME=$HOME/env-claude-cu130/cuda_shim \ + axolotl train nemo_gym_async_agent.yaml +``` + +**Step 3: Training config** (`nemo_gym_async_agent.yaml`): +```yaml +base_model: Qwen/Qwen3-0.6B +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] +sequence_len: 2048 + +rl: grpo +chat_template: tokenizer_default + +trl: + use_vllm: true + vllm_mode: server + vllm_server_host: localhost + vllm_server_port: 8000 + vllm_lora_sync: true + vllm_sync_interval: 5 + # Async GRPO — 3x faster than standard + use_data_producer: true + async_prefetch: true + num_generations: 4 + max_completion_length: 512 + temperature: 0.8 + reward_funcs: + - axolotl.integrations.nemo_gym.rewards.reward_env + +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_auto_start: false +nemo_gym_head_port: 11000 +nemo_gym_multi_turn: true +nemo_gym_verify_timeout: 120 +nemo_gym_datasets: + - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl + server_name: example_single_tool_call + +datasets: + - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl + type: chat_template + field_messages: responses_create_params.input + message_field_content: content + message_field_role: role + +vllm: + gpu_memory_utilization: 0.85 + max_model_len: 2048 + tensor_parallel_size: 1 + +learning_rate: 5e-6 +micro_batch_size: 1 +gradient_accumulation_steps: 4 +max_steps: 30 +gradient_checkpointing: true +bf16: true +output_dir: ./outputs/nemo_gym_async + +use_wandb: true +wandb_project: nemo-gym-rl +``` + +### 3. Single-Turn Training (Simplest — No Agent Server Needed) + +For environments that only need single-turn verify (math, coding challenges), you don't need +an agent server. The plugin's reward function calls `/verify` directly. 
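+
+TRL calls each entry in `reward_funcs` with the batch's prompts, completions, and any extra dataset columns as keyword arguments, and expects one float per completion. The shipped `reward_nemo_gym_verify` follows that contract; a trivial stand-in to illustrate the shape (not the real implementation):
+
+```python
+# Toy reward function with the TRL GRPO signature. The shipped
+# reward_nemo_gym_verify has this shape but scores via /verify.
+def toy_reward(prompts, completions, **kwargs) -> list[float]:
+    rewards = []
+    for completion in completions:
+        # Chat datasets pass completions as message lists; plain text as str.
+        text = completion if isinstance(completion, str) else completion[-1]["content"]
+        rewards.append(1.0 if text.strip() else 0.0)  # placeholder scoring
+    return rewards
+```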
+ +```yaml +base_model: Qwen/Qwen2.5-0.5B-Instruct +rl: grpo +chat_template: tokenizer_default + +trl: + use_vllm: true + vllm_mode: colocate + vllm_enable_sleep_mode: false + num_generations: 8 + max_completion_length: 128 + temperature: 0.9 + reward_funcs: + - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify + +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_auto_start: false +nemo_gym_head_port: 11000 +nemo_gym_datasets: + - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl + server_name: reasoning_gym + +datasets: + - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl + type: chat_template + field_messages: responses_create_params.input + message_field_content: content + message_field_role: role + +vllm: + gpu_memory_utilization: 0.3 + max_model_len: 512 + tensor_parallel_size: 1 + +learning_rate: 1e-5 +micro_batch_size: 4 +gradient_accumulation_steps: 2 +max_steps: 50 +output_dir: ./outputs/nemo_gym_arithmetic +``` + +Only needs `ng_run` with resource servers (no agent config): +```bash +cd ~/Gym && ng_run "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" "+skip_venv_if_present=true" +``` + +## How It Works + +### Single-Turn +```text +axolotl train → GRPO Trainer generates completions + → NeMo Gym plugin reward_fn calls POST /verify on resource server + → reward flows back to GRPO for advantage computation +``` + +### Multi-Turn (Agent /run) +```text +┌─────────────┐ ┌──────────────┐ ┌──────────────────┐ +│ axolotl │ │ NeMo Gym │────▶│ vLLM OpenAI │ +│ train │────▶│ Agent /run │◀────│ Server (GPU 0) │ +│ (GPU 1) │ │ │ │ /v1/completions │ +└─────────────┘ └──────┬───────┘ └──────────────────┘ + │ + ▼ + ┌──────────────┐ + │ Resource │ + │ Server │ + │ (tools + │ + │ verify) │ + └─────────────┘ +``` + +The agent server orchestrates the entire multi-turn loop: +1. Calls our vLLM server for model generation +2. Parses tool calls from model output +3. Executes tools against resource servers +4. Feeds tool results back to the model +5. Repeats until done, then calls /verify for reward +6. Returns token IDs + logprobs + reward to our rollout_func + +### Data Producer Architecture (Multi-Turn) + +When `nemo_gym_multi_turn: true`, the plugin automatically forces `use_data_producer: true` +which selects the `AxolotlAsyncGRPOTrainer`. The plugin then swaps the trainer's data +producer with `NemoGymDataProducer`, which: + +1. Gets a prompt batch from the dataset iterator +2. Expands by `num_generations` (one agent call per rollout) +3. Calls NeMo Gym agents via async HTTP (`aiohttp.gather`) +4. Parses responses into padded tensors (`RolloutDataset`) +5. Returns with `_pending_policy_logps=True` for deferred scoring + +The main thread then runs `_compute_deferred_scores()` which: +- Computes **policy logprobs** on the training model (GPU forward pass) +- Computes **IS correction** using agent's sampling logprobs vs training model logprobs +- Computes advantages with group-level normalization +- All downstream features work: replay buffer, re-roll, streaming, zero-adv skip + +With `async_prefetch: true`, the data producer runs in a background thread — giving ~3x +speedup as generation and training overlap. With `async_prefetch: false`, it runs +synchronously on the main thread (still uses the data producer protocol). 
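+
+The overlap is a plain producer/consumer pattern: a background thread keeps a small queue of rollout batches full while the main thread trains on the previous batch. A bare-bones, runnable sketch of the idea (stubbed timings; not the trainer's actual threading code):
+
+```python
+import queue
+import threading
+import time
+
+batches: queue.Queue = queue.Queue(maxsize=2)
+
+def generate_rollouts(step: int) -> dict:
+    time.sleep(0.1)  # stands in for agent /run calls + parsing
+    return {"step": step}
+
+def prefetch(num_steps: int) -> None:
+    for step in range(num_steps):
+        batches.put(generate_rollouts(step))
+
+threading.Thread(target=prefetch, args=(10,), daemon=True).start()
+for _ in range(10):
+    batch = batches.get()  # the next batch generates while this one trains
+    time.sleep(0.2)        # stands in for the optimizer step
+```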
+ +### Weight Sync (LoRA Mode) + +With `vllm_lora_sync: true`, the plugin (or async trainer) replaces NCCL-based weight +sync with filesystem + HTTP: + +1. `accelerator.get_state_dict()` gathers LoRA weights from all ranks +2. Rank 0 saves adapter to `/tmp/lora_sync_*/vN/` +3. Rank 0 POSTs to `/set_lora_adapter/` on vLLM server +4. vLLM loads adapter natively via Punica kernels +5. Only ~40MB transferred (vs multiple GBs for full model weights) + +### Multi-Environment Support + +Datasets support per-row environment routing via `agent_ref`: +```jsonl +{"agent_ref": {"name": "reasoning_gym"}, "responses_create_params": {...}} +{"agent_ref": {"name": "instruction_following"}, "responses_create_params": {...}} +``` + +Or use the simpler per-dataset routing: +```yaml +nemo_gym_datasets: + - path: reasoning_data.jsonl + server_name: reasoning_gym + - path: tool_data.jsonl + server_name: example_single_tool_call +``` + +## Configuration Reference + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `nemo_gym_enabled` | bool | `null` | Enable the NeMo Gym integration | +| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo | +| `nemo_gym_auto_clone` | bool | `true` | Auto-clone NeMo Gym repo if missing | +| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers | +| `nemo_gym_config_paths` | list[str] | — | Server config YAMLs (relative to gym_dir) | +| `nemo_gym_datasets` | list[dict] | required | Dataset configs with `path` and optional `server_name` | +| `nemo_gym_head_port` | int | `11000` | Head server port | +| `nemo_gym_server_timeout` | int | `360` | Server startup timeout (seconds) | +| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) | +| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent /run | + +### Dataset JSONL Format + +Each line must have `responses_create_params` with `input` messages: +```json +{ + "responses_create_params": { + "input": [{"role": "user", "content": "What's the weather in SF?"}], + "tools": [{"name": "get_weather", "type": "function", "strict": true, "parameters": {...}}] + } +} +``` + +For multi-turn agent routing, include `agent_ref`: +```json +{"agent_ref": {"name": "my_agent"}, "responses_create_params": {...}} +``` + +Note: Tool definitions MUST include `"strict": true` and `"additionalProperties": false` for NeMo Gym agent compatibility. + +### Reward Functions + +The plugin provides two built-in reward functions — no user code needed: + +```yaml +trl: + reward_funcs: + # Multi-turn (nemo_gym_multi_turn: true): + # Passthrough — agent /run already computed the reward + - axolotl.integrations.nemo_gym.rewards.reward_env + + # Single-turn (nemo_gym_multi_turn: false): + # Calls /verify endpoints on NeMo Gym resource servers + - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify +``` + +Both are also importable from Python: + +```python +from axolotl.integrations.nemo_gym import reward_env, reward_nemo_gym_verify +``` + +## Known Issues / Troubleshooting + +### NeMo Gym Server Setup +- **pycosat build failure**: `CFLAGS="" uv pip install pycosat --no-build-isolation` +- **Ray version mismatch**: Pin `ray[default]==2.52.1` in all server venvs +- **Pre-build venvs**: `ng_run` creates per-server venvs via Ray. 
Pre-build them and use `+skip_venv_if_present=true` +- **Tool `strict` field required**: Agent server validates tool definitions require `strict: true` + +### vLLM / Weight Sync +- **Start vLLM with LoRA + tool calling + runtime loading**: + ```bash + VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 \ + CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen3-4B-Instruct-2507 \ + --max-model-len 4096 \ + --gpu-memory-utilization 0.7 \ + --enable-lora --max-lora-rank 64 \ + --enable-auto-tool-choice --tool-call-parser hermes + ``` +- **`VLLM_ALLOW_RUNTIME_LORA_UPDATING=1`**: Required for `vllm_lora_sync: true`. Without it, vLLM won't expose the `/v1/load_lora_adapter` endpoint and weight sync will fail silently. The plugin warns if this endpoint is missing. +- **`--enable-lora`**: Enables LoRA adapter support in vLLM +- **`--enable-auto-tool-choice --tool-call-parser hermes`**: Required for Qwen3 tool calling +- **`max_model_len` must be > `max_completion_length`**: Leave room for prompt tokens (~200). If equal, the NeMo Gym model proxy gets a 400 error and returns empty completions. +- **`CUDA_HOME` required**: DeepSpeed import needs it for the nvcc shim +- **NCCL weight sync broken with vLLM 0.17**: Use `vllm_lora_sync: true` (filesystem + HTTP via `/v1/load_lora_adapter`) + +### Multi-Turn +- **Agent server required**: Multi-turn delegates to NeMo Gym's agent server `/run` endpoint. Without an agent, the plugin falls back to single-turn `/verify` +- **Model server proxy**: NeMo Gym needs a `responses_api_models` server that proxies to your vLLM. See the agent config example above + +### FSDP2 +- Validated on 2 GPUs with single-turn + LoRA +- Async field filtering: The builder automatically filters async-only config fields when using the standard GRPO trainer + +## Comparison with Other Integrations + +| Feature | Axolotl + NeMo Gym | Unsloth + NeMo Gym | NeMo RL (native) | +|---------|-------------------|-------------------|-------------------| +| Server management | Automatic | Manual (notebook) | Built-in | +| Multi-environment | Per-row routing | Manual code | YAML config | +| Multi-turn / tool use | Agent /run delegation | No | Agent /run (Ray) | +| Async GRPO (3x speedup) | Yes | No | Yes | +| LoRA sync | Filesystem + HTTP | N/A | NCCL | +| Multi-GPU (FSDP2) | Yes | No | Yes (Ray) | +| Config-driven | Yes | No (code) | Yes | diff --git a/src/axolotl/integrations/nemo_gym/__init__.py b/src/axolotl/integrations/nemo_gym/__init__.py new file mode 100644 index 000000000..b880bee7f --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +Plugin for NVIDIA NeMo Gym integration with Axolotl. + +NeMo Gym provides RL training environments for LLMs with verification-based +reward signals. This plugin manages the NeMo Gym server lifecycle, loads +datasets in the NeMo Gym JSONL format, and creates reward functions that +call the NeMo Gym /verify endpoints. 
+""" + +from .args import NemoGymArgs +from .plugin import NemoGymPlugin +from .rewards import reward_env, reward_nemo_gym_verify + +__all__ = [ + "NemoGymArgs", + "NemoGymPlugin", + "reward_env", + "reward_nemo_gym_verify", +] diff --git a/src/axolotl/integrations/nemo_gym/args.py b/src/axolotl/integrations/nemo_gym/args.py new file mode 100644 index 000000000..3c593fa80 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/args.py @@ -0,0 +1,146 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +Input arguments for the NeMo Gym integration plugin. +""" + +from pydantic import BaseModel, Field, model_validator + + +class NemoGymArgs(BaseModel): + """Configuration args for the NeMo Gym integration.""" + + nemo_gym_enabled: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Enable NeMo Gym integration for environment-based RL rewards." + }, + ) + nemo_gym_dir: str | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Path to the NeMo Gym repository clone. " + "If not set and nemo_gym_auto_clone is True, clones to ~/Gym." + ) + }, + ) + nemo_gym_auto_clone: bool | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Automatically clone the NeMo Gym repository if not present. " + "Defaults to True when nemo_gym_enabled is set." + ) + }, + ) + nemo_gym_config_paths: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "List of NeMo Gym resource server config YAML paths, relative to nemo_gym_dir. " + "Example: ['resources_servers/reasoning_gym/configs/resources_only.yaml']" + ) + }, + ) + nemo_gym_head_port: int | None = Field( + default=11000, + json_schema_extra={ + "description": "Port for the NeMo Gym head server. Defaults to 11000." + }, + ) + nemo_gym_server_timeout: int | None = Field( + default=360, + json_schema_extra={ + "description": "Timeout in seconds waiting for NeMo Gym servers to start. Defaults to 360." + }, + ) + nemo_gym_verify_timeout: int | None = Field( + default=30, + json_schema_extra={ + "description": "Timeout in seconds for individual /verify requests. Defaults to 30." + }, + ) + nemo_gym_run_timeout: int | None = Field( + default=300, + json_schema_extra={ + "description": ( + "Timeout in seconds for each agent /run request (one multi-turn rollout). " + "Prevents stuck generations (e.g. model looping on tags) from " + "blocking training indefinitely. Defaults to 300 (5 minutes)." + ) + }, + ) + nemo_gym_datasets: list[dict] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "List of NeMo Gym dataset configs. Each entry has 'path' (JSONL file path " + "relative to nemo_gym_dir) and optionally 'server_name' (default resource server). " + "If the JSONL rows have agent_ref.name, that takes precedence per row, " + "enabling multi-environment training from a single dataset file. " + "Optional 'max_samples' to limit per dataset." + ) + }, + ) + nemo_gym_auto_start: bool | None = Field( + default=True, + json_schema_extra={ + "description": ( + "Automatically start NeMo Gym resource servers. Defaults to True. " + "Set to False if servers are already running externally." + ) + }, + ) + nemo_gym_model_name: str | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Model name to report in verify requests. 
" + "Defaults to the base_model from the main config." + ) + }, + ) + nemo_gym_multi_turn: bool | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Enable multi-turn rollouts via NeMo Gym. When True, uses TRL's " + "rollout_func to run multi-step interactions with tool execution. " + "Requires use_vllm=True in TRL config. The model generates responses, " + "tool calls are executed against resource servers, and results are " + "fed back for the next turn. Final reward comes from /verify." + ) + }, + ) + nemo_gym_max_turns: int | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Maximum number of turns per multi-turn rollout. Defaults to 10. " + "Each turn consists of a model generation + optional tool execution." + ) + }, + ) + + @model_validator(mode="before") + @classmethod + def check_nemo_gym_config(cls, data): + if data.get("nemo_gym_enabled"): + if not data.get("nemo_gym_config_paths") and data.get( + "nemo_gym_auto_start", True + ): + raise ValueError( + "nemo_gym_config_paths is required when nemo_gym_enabled=True " + "and nemo_gym_auto_start is not False." + ) + if not data.get("nemo_gym_datasets"): + raise ValueError( + "nemo_gym_datasets is required when nemo_gym_enabled=True. " + "Provide at least one dataset with 'path' and 'server_name'." + ) + return data diff --git a/src/axolotl/integrations/nemo_gym/data_producer.py b/src/axolotl/integrations/nemo_gym/data_producer.py new file mode 100644 index 000000000..64b76d780 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/data_producer.py @@ -0,0 +1,226 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +NeMo Gym Data Producer for async GRPO training. + +Replaces GRPODataProducer to generate rollouts via NeMo Gym agent /run endpoints +instead of vLLM. The agent handles generation, tool execution, and reward computation. +Returns RolloutDataset in the same format as the standard producer, so all downstream +components (deferred scoring, IS correction, streaming, replay, re-roll) work unchanged. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import torch +from trl.trainer.utils import pad + +from axolotl.core.trainers.grpo.async_trainer import GRPODataProducer, RolloutDataset +from axolotl.utils.logging import get_logger + +from .multi_turn import _call_agents, _parse_agent_response + +LOG = get_logger(__name__) + + +class NemoGymDataProducer(GRPODataProducer): + """Produces GRPO rollouts by calling NeMo Gym agent /run endpoints. + + Drop-in replacement for GRPODataProducer. Instead of calling vLLM for generation, + sends prompts to NeMo Gym agents which handle generation + tool execution + reward. + Returns the same RolloutDataset format so deferred scoring, IS correction, + replay buffer, and re-roll all work unchanged. 
+ """ + + def __init__( + self, + *args, + agent_servers: dict[str, str], + dataset_lookup: dict, + request_timeout: float = 10800, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._agent_servers = agent_servers + self._dataset_lookup = dataset_lookup + self._request_timeout = request_timeout + + def produce( + self, + model: Any, + global_step: int, + *, + skip_policy_logps: bool = False, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + _rank0_only: bool = False, + **kwargs, + ) -> RolloutDataset | None: + """Generate rollouts via NeMo Gym agents. + + Calls agent /run endpoints, parses responses into padded tensors, + and returns a RolloutDataset for deferred scoring on the main thread. + """ + trainer = self._trainer + is_main = trainer.accelerator.is_main_process + device = trainer.accelerator.device + + if _rank0_only and not is_main: + return None + + # Get prompt batch from iterator + try: + inputs = next(self._prompt_iter) + except StopIteration: + self._prompt_iter = iter(self._prompt_dl) + inputs = next(self._prompt_iter) + + # Extract dataset items for agent calls + dataset_items = [] + for inp in inputs: + prompt_text = "" + prompt = inp.get("prompt", []) + if isinstance(prompt, list) and prompt: + prompt_text = ( + prompt[-1].get("content", "") + if isinstance(prompt[-1], dict) + else str(prompt[-1]) + ) + elif isinstance(prompt, str): + prompt_text = prompt + + # Find the full dataset item, preserving agent_ref for routing + full_item = self._dataset_lookup.get(prompt_text, {}) + item = full_item.get("verify_extra", {}) + if not item: + item = { + "responses_create_params": { + "input": [{"role": "user", "content": prompt_text}] + } + } + # Preserve agent_ref from the dataset row for _call_agents routing + if "agent_ref" in full_item and "agent_ref" not in item: + item["agent_ref"] = full_item["agent_ref"] + dataset_items.append(item) + + # Expand by num_generations (agent produces one rollout per call) + expanded_items = [] + for item in dataset_items: + for _ in range(self._num_generations): + expanded_items.append(item) + + # Call NeMo Gym agents + loop = asyncio.new_event_loop() + try: + responses = loop.run_until_complete( + _call_agents( + dataset_items=expanded_items, + agent_servers=self._agent_servers, + timeout=self._request_timeout, + max_completion_length=trainer.max_completion_length, + temperature=trainer.temperature, + top_p=getattr(trainer, "top_p", None) or 0.999, + ) + ) + finally: + loop.close() + + # Parse responses + eos_token_id = trainer.processing_class.eos_token_id + prompt_ids_list = [] + completion_ids_list = [] + env_mask_list = [] + logprobs_list = [] + rewards_list = [] + + for resp in responses: + parsed = _parse_agent_response(resp, eos_token_id) + prompt_ids_list.append(parsed["prompt_ids"]) + completion_ids_list.append(parsed["completion_ids"]) + env_mask_list.append(parsed["env_mask"]) + logprobs_list.append(parsed["logprobs"]) + rewards_list.append(parsed["reward"]) + + # Pad to tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad( + prompt_ids, padding_value=trainer.pad_token_id, padding_side="left" + ) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + + completion_ids = [ + torch.tensor(ids, device=device) for ids in completion_ids_list + ] + completion_mask = [ + torch.ones_like(ids, dtype=torch.long) for ids in completion_ids + ] + 
completion_ids = pad( + completion_ids, padding_value=trainer.pad_token_id, padding_side="right" + ) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + + # Sampling logprobs from agent (used for IS correction) + sampling_logps = [ + torch.tensor(lp, dtype=torch.float32, device=device) for lp in logprobs_list + ] + sampling_per_token_logps = pad( + sampling_logps, padding_value=0.0, padding_side="right" + ) + + # env_mask as tool_mask (1=model tokens, 0=tool tokens) + tool_mask = [torch.tensor(m, device=device) for m in env_mask_list] + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") + + # Inject rewards into inputs so _compute_deferred_scores can use them + # The deferred scoring path calls _calculate_rewards which reads reward_funcs. + # Our passthrough reward_fn reads "env_reward" from kwargs. + for i, inp in enumerate(inputs): + # Each input gets rewards for its num_generations rollouts + start = i * self._num_generations + end = start + self._num_generations + inp["env_reward"] = rewards_list[start:end] + + # Expand inputs to match expanded rollouts (num_generations copies) + expanded_inputs = [] + for inp in inputs: + for g in range(self._num_generations): + expanded_inp = dict(inp) + expanded_inp["env_reward"] = inp["env_reward"][g] + expanded_inputs.append(expanded_inp) + + # Decode completions for reward functions + completions = trainer.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + + # Build total token count + num_items_in_batch = completion_mask.sum() + + # Build output dict (same shape as _generate_only) + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "num_items_in_batch": num_items_in_batch, + "advantages": torch.zeros(completion_ids.size(0), device=device), + "sampling_per_token_logps": sampling_per_token_logps, + "tool_mask": tool_mask, + # Deferred scoring markers + "_pending_policy_logps": True, + "_deferred_inputs": expanded_inputs, + "_deferred_prompts": [inp.get("prompt", "") for inp in expanded_inputs], + "_deferred_completions": completions, + "_deferred_completion_ids_list": completion_ids_list, + "_rank0_only": _rank0_only, + } + + return RolloutDataset(output) diff --git a/src/axolotl/integrations/nemo_gym/dataset.py b/src/axolotl/integrations/nemo_gym/dataset.py new file mode 100644 index 000000000..be53da52d --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/dataset.py @@ -0,0 +1,135 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +Dataset loading for NeMo Gym JSONL files. + +Converts NeMo Gym JSONL format into HuggingFace Datasets compatible +with TRL's GRPOTrainer. Supports multi-environment routing via: + 1. Per-dataset server_name (all rows in a file go to one server) + 2. Per-row agent_ref.name (each row specifies its own server) +""" + +from __future__ import annotations + +import json +import os +import random + +from datasets import Dataset + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def load_nemo_gym_datasets( + gym_dir: str, + dataset_configs: list[dict], +) -> Dataset: + """Load and merge NeMo Gym JSONL datasets with multi-environment support. 
+ + Each dataset config should have: + - path: JSONL file path (absolute, or relative to gym_dir) + - server_name: Default NeMo Gym server for this dataset. + Can be overridden per-row if the JSONL has an "agent_ref" field. + - max_samples (optional): Max number of samples to use from this dataset + + Per-row routing: If a JSONL row has an "agent_ref": {"name": "..."} field, + that takes precedence over the dataset-level server_name. This allows mixing + environments within a single dataset file (matching TRL's pattern). + + The output dataset has columns: + - prompt: list[dict] chat format + - resources_server_ref: dict with {"name": server_name} + - verify_extra: dict with original JSONL data for verify requests + + Args: + gym_dir: Path to the NeMo Gym directory. + dataset_configs: List of dataset configuration dicts. + + Returns: + A HuggingFace Dataset ready for GRPOTrainer. + """ + all_examples = [] + + for ds_cfg in dataset_configs: + path = ds_cfg["path"] + default_server = ds_cfg.get("server_name", "") + max_samples = ds_cfg.get("max_samples") + + # Resolve path + if not os.path.isabs(path): + path = os.path.join(gym_dir, path) + path = os.path.expanduser(path) + + if not os.path.exists(path): + raise FileNotFoundError( + f"NeMo Gym dataset not found at {path}. " + "Ensure the dataset file exists or run the appropriate " + "NeMo Gym dataset creation script." + ) + + LOG.info( + f"Loading NeMo Gym dataset from {path} (default server: {default_server})" + ) + + with open(path, encoding="utf-8") as f: + lines = f.readlines() + + if max_samples and len(lines) > max_samples: + lines = random.sample(lines, max_samples) # nosec B311 + + for line in lines: + data = json.loads(line) + + # Extract user prompt from the input messages + inputs = data.get("responses_create_params", {}).get("input", []) + task_prompt = "" + for inp in inputs: + if isinstance(inp, dict) and inp.get("role") in ("user",): + task_prompt = inp.get("content", "") + break + if not task_prompt and inputs: + # Fallback: use the last input's content + task_prompt = ( + inputs[-1].get("content", "") + if isinstance(inputs[-1], dict) + else "" + ) + + # Per-row agent routing: agent_ref.name can override dataset-level server_name. + # NeMo Gym datasets may use agent names (e.g., "reasoning_gym_simple_agent") + # which differ from resource server names (e.g., "reasoning_gym"). + # The dataset-level server_name is always the fallback. + row_agent_ref = data.get("agent_ref", {}) + server_name = default_server + if row_agent_ref and row_agent_ref.get("name"): + # Use per-row name, but only if it looks like a resource server name. + # Agent names typically have "_simple_agent" or "_agent" suffix. 
+ row_name = row_agent_ref["name"] + if row_agent_ref.get("type") != "responses_api_agents": + # Not an agent — could be a direct resource server reference + server_name = row_name + + all_examples.append( + { + "prompt": [{"role": "user", "content": task_prompt}], + "resources_server_ref": {"name": server_name}, + "verify_extra": data, + } + ) + + random.shuffle(all_examples) + + # Log environment distribution + env_counts: dict[str, int] = {} + for ex in all_examples: + name = ex["resources_server_ref"]["name"] + env_counts[name] = env_counts.get(name, 0) + 1 + LOG.info(f"Loaded {len(all_examples)} NeMo Gym examples: {env_counts}") + + return Dataset.from_list(all_examples) diff --git a/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_env.yaml b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_env.yaml new file mode 100644 index 000000000..5a86f7578 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_env.yaml @@ -0,0 +1,64 @@ +# Axolotl + NeMo Gym: Multi-Environment RL Training Example +# +# Trains on multiple NeMo Gym environments simultaneously: +# - Mini Sudoku (reasoning_gym) +# - Instruction Following +# +# Prerequisites: +# - uv (https://github.com/astral-sh/uv) installed +# - Download instruction_following.jsonl from nvidia/Nemotron-RL-instruction_following +# and place it at ~/Gym/resources_servers/instruction_following/data/instruction_following.jsonl +# +# Usage: +# axolotl train examples/nemo_gym_multi_env.yaml + +base_model: Qwen/Qwen2.5-1.5B-Instruct +model_type: AutoModelForCausalLM + +sequence_len: 4096 +load_in_4bit: false + +# RL configuration +rl: grpo +chat_template: tokenizer_default + +# GRPO / TRL settings +trl: + use_vllm: false + num_generations: 8 + max_completion_length: 2048 + +# NeMo Gym plugin +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_dir: ~/Gym +nemo_gym_auto_clone: true +nemo_gym_auto_start: true +nemo_gym_config_paths: + - resources_servers/reasoning_gym/configs/resources_only.yaml + - resources_servers/instruction_following/configs/instruction_following.yaml +nemo_gym_datasets: + - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl + server_name: reasoning_gym + max_samples: 1000 + - path: resources_servers/instruction_following/data/instruction_following.jsonl + server_name: instruction_following + max_samples: 1000 + +# Training hyperparameters +learning_rate: 1.0e-5 +weight_decay: 0.001 +lr_scheduler: linear +warmup_ratio: 0.0 +optimizer: adamw_8bit + +num_epochs: 1 +max_steps: 100 +micro_batch_size: 1 +gradient_accumulation_steps: 64 + +logging_steps: 1 +save_steps: 100 +output_dir: ./outputs/nemo_gym_multi_env diff --git a/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_turn.yaml b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_turn.yaml new file mode 100644 index 000000000..2644bbab7 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_multi_turn.yaml @@ -0,0 +1,79 @@ +# Axolotl + NeMo Gym: Multi-Turn Tool-Use RL Training Example +# +# This config trains a model with multi-turn reinforcement learning +# using NeMo Gym's agentic environments. The model generates responses, +# executes tool calls against NeMo Gym resource servers, and receives +# environment feedback across multiple turns. +# +# Multi-turn mode uses TRL's rollout_func with env_mask to ensure only +# model-generated tokens contribute to the training loss. 
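+# (env_mask marks each completion token: 1 = model-generated, included in the
+# loss; 0 = tool-result tokens injected by the environment, masked out.)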
+# +# Prerequisites: +# - uv (https://github.com/astral-sh/uv) installed +# - vLLM must be enabled (required for generate_rollout_completions) +# +# Usage: +# # Terminal 1: Start vLLM server +# trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct +# +# # Terminal 2: Train +# axolotl train examples/nemo_gym_multi_turn.yaml + +base_model: Qwen/Qwen2.5-1.5B-Instruct +model_type: AutoModelForCausalLM + +sequence_len: 4096 +load_in_4bit: false + +# RL configuration +rl: grpo +chat_template: tokenizer_default + +# GRPO / TRL settings — vLLM required for multi-turn +trl: + use_vllm: true + vllm_mode: server + vllm_server_host: localhost + vllm_server_port: 8000 + num_generations: 4 + max_completion_length: 2048 + +# NeMo Gym plugin with multi-turn enabled +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_dir: ~/Gym +nemo_gym_auto_clone: true +nemo_gym_auto_start: true +nemo_gym_head_port: 11000 +nemo_gym_server_timeout: 360 +nemo_gym_verify_timeout: 30 + +# Multi-turn settings +nemo_gym_multi_turn: true +nemo_gym_max_turns: 10 + +# Resource server configs — use environments that support tool calls +nemo_gym_config_paths: + - resources_servers/reasoning_gym/configs/resources_only.yaml +nemo_gym_datasets: + - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl + server_name: reasoning_gym + max_samples: 2000 + +# Training hyperparameters +learning_rate: 1.0e-5 +weight_decay: 0.001 +lr_scheduler: linear +warmup_ratio: 0.0 +optimizer: adamw_8bit + +num_epochs: 1 +max_steps: 100 +micro_batch_size: 1 +gradient_accumulation_steps: 64 + +logging_steps: 1 +save_steps: 100 +output_dir: ./outputs/nemo_gym_multi_turn diff --git a/src/axolotl/integrations/nemo_gym/examples/nemo_gym_sudoku.yaml b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_sudoku.yaml new file mode 100644 index 000000000..df04b2b41 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/examples/nemo_gym_sudoku.yaml @@ -0,0 +1,62 @@ +# Axolotl + NeMo Gym: Sudoku RL Training Example +# +# This config trains a model to solve 4x4 Mini Sudoku puzzles using GRPO +# with NeMo Gym's reasoning_gym environment as the reward signal. 
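+# Each completion is scored by the environment's /verify endpoint, and that
+# score is used directly as the GRPO reward.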
+# +# Prerequisites: +# - uv (https://github.com/astral-sh/uv) installed +# - Git installed +# - The NeMo Gym repo will be auto-cloned to ~/Gym on first run +# +# Usage: +# axolotl train examples/nemo_gym_sudoku.yaml + +base_model: Qwen/Qwen2.5-1.5B-Instruct +model_type: AutoModelForCausalLM + +sequence_len: 4096 +load_in_4bit: false + +# RL configuration +rl: grpo +chat_template: tokenizer_default + +# GRPO / TRL settings +trl: + use_vllm: false + num_generations: 8 + max_completion_length: 2048 + +# NeMo Gym plugin +plugins: + - axolotl.integrations.nemo_gym.NemoGymPlugin + +nemo_gym_enabled: true +nemo_gym_dir: ~/Gym +nemo_gym_auto_clone: true +nemo_gym_auto_start: true +nemo_gym_head_port: 11000 +nemo_gym_server_timeout: 360 +nemo_gym_verify_timeout: 30 +nemo_gym_config_paths: + - resources_servers/reasoning_gym/configs/resources_only.yaml +nemo_gym_datasets: + - path: resources_servers/reasoning_gym/data/train_mini_sudoku.jsonl + server_name: reasoning_gym + max_samples: 2000 + +# Training hyperparameters +learning_rate: 1.0e-5 +weight_decay: 0.001 +lr_scheduler: linear +warmup_ratio: 0.0 +optimizer: adamw_8bit + +num_epochs: 1 +max_steps: 100 +micro_batch_size: 1 +gradient_accumulation_steps: 64 + +logging_steps: 1 +save_steps: 100 +output_dir: ./outputs/nemo_gym_sudoku diff --git a/src/axolotl/integrations/nemo_gym/multi_turn.py b/src/axolotl/integrations/nemo_gym/multi_turn.py new file mode 100644 index 000000000..328a393bc --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/multi_turn.py @@ -0,0 +1,329 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +Multi-turn rollout function for NeMo Gym environments. + +Delegates multi-turn orchestration to NeMo Gym's agent servers via the /run +endpoint. The agent handles generation (by calling our vLLM server), tool +execution, session management, and reward computation. + +This follows the same pattern as TRL's reference implementation at +examples/scripts/nemo_gym/train_multi_environment.py. + +Architecture: + rollout_func(prompts, trainer) + -> expand prompts by num_generations + -> async POST /run to agent servers (one per sample) + -> parse response: prompt_ids, completion_ids, logprobs, env_mask, reward + -> return to TRL for GRPO training +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def create_nemo_gym_rollout_func( + agent_servers: dict[str, str], + dataset_lookup: dict[int, dict], + request_timeout: float = 10800, +): + """Create a TRL-compatible rollout_func that delegates to NeMo Gym agents. + + Args: + agent_servers: Mapping of agent_name → agent URL (e.g., {"simple_agent": "http://host:port"}). + dataset_lookup: Mapping of dataset index → full JSONL row dict. + request_timeout: HTTP timeout for /run requests. + + Returns: + A rollout_func with signature (prompts: list[str], trainer) -> dict. 
+ """ + + def rollout_func(prompts: list[str], trainer) -> dict[str, Any]: + is_training = trainer.model.training + num_generations = ( + trainer.num_generations + if is_training + else getattr(trainer, "num_generations_eval", 1) + ) + temperature = trainer.temperature + top_p = getattr(trainer, "top_p", None) or 0.999 + max_completion_length = trainer.max_completion_length + eos_token_id = trainer.processing_class.eos_token_id + + # Expand prompts: each prompt index repeated num_generations times + expanded_items = [] + expanded_prompt_indices = [] + for prompt_str in prompts: + # Prompts from TRL are chat-templated strings. Find the dataset item + # by matching against dataset_lookup keys (raw user message text). + full_item = None + for key, val in dataset_lookup.items(): + if isinstance(key, str) and prompt_str == key: + full_item = val + break + + if full_item is None: + full_item = { + "responses_create_params": { + "input": [{"role": "user", "content": prompt_str}] + } + } + + for _ in range(num_generations): + # Preserve agent_ref for routing in _call_agents + dispatched: dict = full_item.get("verify_extra", full_item) # type: ignore[assignment] + if isinstance(dispatched, dict) and "agent_ref" not in dispatched: + agent_ref = full_item.get("agent_ref") + if agent_ref: + dispatched = {**dispatched, "agent_ref": agent_ref} + expanded_items.append(dispatched) + expanded_prompt_indices.append(prompt_str) + + # Call NeMo Gym agents + loop = asyncio.new_event_loop() + try: + responses = loop.run_until_complete( + _call_agents( + dataset_items=expanded_items, + agent_servers=agent_servers, + timeout=request_timeout, + max_completion_length=max_completion_length, + temperature=temperature, + top_p=top_p, + ) + ) + finally: + loop.close() + + # Parse responses into rollout format + all_prompt_ids = [] + all_completion_ids = [] + all_env_masks = [] + all_logprobs = [] + all_rewards = [] + all_num_turns = [] + + for _i, response in enumerate(responses): + result = _parse_agent_response(response, eos_token_id) + all_prompt_ids.append(result["prompt_ids"]) + all_completion_ids.append(result["completion_ids"]) + all_env_masks.append(result["env_mask"]) + all_logprobs.append(result["logprobs"]) + all_rewards.append(result["reward"]) + all_num_turns.append(result["num_turns"]) + + # TRL expects prompt_ids to be unique (one per original prompt, not per generation) + unique_prompt_ids = all_prompt_ids[::num_generations] + + # Wrap logprobs for TRL: list[list[list[float]]] + def _normalize(lp): + while isinstance(lp, (list, tuple)) and len(lp) > 0: + lp = lp[0] + return float(lp) if lp is not None else 0.0 + + wrapped_logprobs = [[[_normalize(lp)] for lp in seq] for seq in all_logprobs] + + return { + "prompt_ids": unique_prompt_ids, + "completion_ids": all_completion_ids, + "env_mask": all_env_masks, + "logprobs": wrapped_logprobs, + "logprob_token_ids": None, # nosec B105 + "env_reward": all_rewards, + "num_turns": all_num_turns, + } + + return rollout_func + + +async def _call_agents( + dataset_items: list[dict], + agent_servers: dict[str, str], + timeout: float, + max_completion_length: int = 4096, + temperature: float = 1.0, + top_p: float = 0.999, +) -> list[dict]: + """Async batch POST to NeMo Gym agent /run endpoints.""" + import aiohttp + + results = [] + connector = aiohttp.TCPConnector(limit_per_host=64, limit=256) + # Use sock_read for per-request timeout (not total session timeout). + # This ensures a single stuck generation doesn't block all other requests. 
+ client_timeout = aiohttp.ClientTimeout(total=None, sock_read=timeout) + + async with aiohttp.ClientSession( + connector=connector, timeout=client_timeout, cookie_jar=aiohttp.DummyCookieJar() + ) as session: + tasks = [] + for item in dataset_items: + agent_ref = item.get("agent_ref", {}) + agent_name = agent_ref.get("name", "") + agent_url = agent_servers.get(agent_name, "") + + if not agent_url: + # Fallback: try first available agent + if agent_servers: + agent_url = next(iter(agent_servers.values())) + else: + results.append( + { + "response": {"output": []}, + "reward": 0.0, + "error": "No agent server", + } + ) + continue + + # Build request body + request_body = dict(item) + params = request_body.setdefault("responses_create_params", {}) + params.setdefault("max_output_tokens", max_completion_length) + params["temperature"] = temperature + params["top_p"] = top_p + + tasks.append(_post_run(session, agent_url, request_body)) + + if tasks: + responses = await asyncio.gather(*tasks, return_exceptions=True) + for resp in responses: + if isinstance(resp, BaseException): + LOG.warning(f"Agent /run failed: {resp}") + results.append( + {"response": {"output": []}, "reward": 0.0, "error": str(resp)} + ) + else: + results.append(resp) + + return results + + +async def _post_run(session, agent_url: str, body: dict) -> dict: + """POST to agent /run endpoint.""" + async with session.post(f"{agent_url}/run", json=body) as resp: + if resp.status == 200: + return await resp.json() + text = await resp.text() + return { + "response": {"output": []}, + "reward": 0.0, + "error": f"HTTP {resp.status}: {text[:200]}", + } + + +def _parse_agent_response(response: dict, eos_token_id: int) -> dict: + """Parse NeMo Gym agent /run response into rollout format. + + The agent returns: + response.output[]: list of turns, each with prompt_token_ids, + generation_token_ids, generation_log_probs + reward: float + """ + # Defaults for failed/empty responses + defaults = { + "prompt_ids": [eos_token_id], + "completion_ids": [eos_token_id], + "env_mask": [0], + "logprobs": [0.0], + "reward": 0.0, + "num_turns": 0, + } + + if not isinstance(response, dict) or "error" in response: + return defaults + + output_items = response.get("response", {}).get("output", []) + reward = float(response.get("reward", 0.0)) + + if not output_items: + defaults["reward"] = reward + return defaults + + # Check at least one valid output + has_valid = False + for item in output_items: + if item.get("type") == "function_call": + has_valid = True + break + if item.get("type") == "message": + for c in item.get("content", []): + if ( + isinstance(c, dict) + and c.get("type") == "output_text" + and c.get("text", "").strip() + ): + has_valid = True + break + if has_valid: + break + + if not has_valid: + defaults["reward"] = reward + return defaults + + # Extract multi-turn token sequences + first_prompt_ids = None + completion_ids = [] + env_mask = [] + logprobs = [] + seen_token_ids = [] + num_turns = 0 + + for item in output_items: + prompt_token_ids = item.get("prompt_token_ids", []) + generation_token_ids = item.get("generation_token_ids", []) + generation_log_probs = item.get("generation_log_probs", []) + + if not generation_token_ids: + continue + + num_turns += 1 + + # First turn: capture prompt + if first_prompt_ids is None: + first_prompt_ids = list(prompt_token_ids) + seen_token_ids = list(prompt_token_ids) + else: + # Subsequent turns: extract tool result tokens (between turns) + if len(prompt_token_ids) > len(seen_token_ids): + 
tool_result_tokens = prompt_token_ids[len(seen_token_ids) :] + # Tool result tokens are NOT trained on (env_mask = 0) + completion_ids.extend(tool_result_tokens) + env_mask.extend([0] * len(tool_result_tokens)) + logprobs.extend([0.0] * len(tool_result_tokens)) + + # Add generation tokens (trained on, env_mask = 1) + completion_ids.extend(generation_token_ids) + env_mask.extend([1] * len(generation_token_ids)) + + # Pad logprobs if shorter than generation tokens + gen_logprobs = list(generation_log_probs) if generation_log_probs else [] + if len(gen_logprobs) < len(generation_token_ids): + gen_logprobs.extend([0.0] * (len(generation_token_ids) - len(gen_logprobs))) + logprobs.extend(gen_logprobs[: len(generation_token_ids)]) + + # Update seen tokens + seen_token_ids = list(prompt_token_ids) + list(generation_token_ids) + + if first_prompt_ids is None: + return defaults + + return { + "prompt_ids": first_prompt_ids, + "completion_ids": completion_ids, + "env_mask": env_mask, + "logprobs": logprobs, + "reward": reward, + "num_turns": num_turns, + } diff --git a/src/axolotl/integrations/nemo_gym/plugin.py b/src/axolotl/integrations/nemo_gym/plugin.py new file mode 100644 index 000000000..14de684cf --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/plugin.py @@ -0,0 +1,503 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +NeMo Gym Plugin for Axolotl. + +Integrates NVIDIA NeMo Gym environments as reward sources for GRPO training. +Handles server lifecycle, dataset loading, and reward function wiring. + +Supports two modes: + - Single-turn (default): reward_fn calls /verify after each generation + - Multi-turn (nemo_gym_multi_turn: true): rollout_func orchestrates + multi-step interactions with tool execution via resource servers +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Union + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from axolotl.common.datasets import TrainDatasetMeta + +LOG = get_logger(__name__) + + +class NemoGymPlugin(BasePlugin): + """Plugin for NVIDIA NeMo Gym integration with Axolotl. + + When enabled, this plugin: + 1. Clones and sets up the NeMo Gym repo (if needed) + 2. Starts NeMo Gym resource servers + 3. Loads datasets from NeMo Gym JSONL files + 4. For single-turn: creates a reward function calling /verify + 5. For multi-turn: creates a rollout_func with tool execution and env_mask + """ + + def __init__(self): + super().__init__() + self._gym_dir = None + self._global_config = None + self._verify_endpoints = None + self._server_base_urls = None + self._reward_fn = None + self._dataset_lookup = None + self._agent_servers = {} + + def get_input_args(self): + return "axolotl.integrations.nemo_gym.NemoGymArgs" + + def pre_model_load(self, cfg): + """Apply monkeypatches before trainer creation.""" + if not cfg.nemo_gym_enabled: + return + + # Always skip NCCL communicator init in NeMo Gym mode. + # NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL + # colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers. 
+ trl_cfg = getattr(cfg, "trl", None) + if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server": + self._patch_skip_nccl_init() + + def _patch_skip_nccl_init(self): + """Monkeypatch VLLMClient.init_communicator to no-op. + + NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA + serve script). The NCCL communicator is not needed and fails with both + vLLM V1 engine and standard OpenAI server mode. + """ + try: + from trl.generation.vllm_client import VLLMClient + + VLLMClient._original_init_communicator = VLLMClient.init_communicator + VLLMClient.init_communicator = lambda self, **kwargs: LOG.info( + "Skipping NCCL init_communicator (LoRA sync mode)" + ) + LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync") + except Exception as exc: + LOG.warning(f"Failed to patch VLLMClient: {exc}") + + def register(self, cfg): + if not cfg.get("nemo_gym_enabled"): + return + + LOG.info("NeMo Gym integration enabled") + gym_dir = cfg.get("nemo_gym_dir") or os.path.expanduser("~/Gym") + auto_clone = cfg.get("nemo_gym_auto_clone", True) + auto_start = cfg.get("nemo_gym_auto_start", True) + head_port = cfg.get("nemo_gym_head_port", 11000) + server_timeout = cfg.get("nemo_gym_server_timeout", 360) + + from .server import ( + ensure_gym_repo, + ensure_gym_venv, + get_agent_servers, + get_server_base_url, + get_server_configs, + get_verify_endpoint, + start_servers, + wait_for_resource_servers, + ) + + self._gym_dir = ensure_gym_repo(gym_dir, auto_clone=auto_clone) + + if auto_start: + config_paths = cfg.get("nemo_gym_config_paths", []) + ensure_gym_venv(self._gym_dir) + start_servers( + self._gym_dir, + config_paths, + head_port=head_port, + timeout=server_timeout, + ) + + self._global_config = get_server_configs(head_port=head_port) + wait_for_resource_servers(self._global_config, timeout=server_timeout) + + # Build endpoint maps for resource servers (/verify) + self._verify_endpoints = {} + self._server_base_urls = {} + for server_name in self._global_config: + try: + self._verify_endpoints[server_name] = get_verify_endpoint( + self._global_config, server_name + ) + self._server_base_urls[server_name] = get_server_base_url( + self._global_config, server_name + ) + except (ValueError, KeyError, TypeError): + pass + + # Discover agent servers (/run) for multi-turn + self._agent_servers = get_agent_servers(self._global_config) + + # Pre-build dataset lookup for multi-turn (needs to happen at register time, + # not load_datasets, because load_datasets may not be called if axolotl config + # has its own datasets field) + if cfg.get("nemo_gym_multi_turn") and cfg.get("nemo_gym_datasets"): + from .dataset import load_nemo_gym_datasets + + gym_dir = cfg.get("nemo_gym_dir") or os.path.expanduser("~/Gym") + dataset = load_nemo_gym_datasets(gym_dir, cfg["nemo_gym_datasets"]) + self._dataset_lookup = {} + for i in range(len(dataset)): + row = dataset[i] + # Use last message content as key (matches data_producer lookup) + prompt_text = row["prompt"][-1]["content"] + self._dataset_lookup[prompt_text] = row + LOG.info(f"Built dataset lookup with {len(self._dataset_lookup)} entries") + + multi_turn = cfg.get("nemo_gym_multi_turn", False) + LOG.info( + f"NeMo Gym ready with servers: {list(self._verify_endpoints.keys())} " + f"(multi_turn={'enabled' if multi_turn else 'disabled'})" + ) + + def load_datasets(self, cfg, preprocess=False) -> Union["TrainDatasetMeta", None]: + if not cfg.nemo_gym_enabled: + return None + + from axolotl.common.datasets import TrainDatasetMeta + + 
from .dataset import load_nemo_gym_datasets + + dataset_configs = cfg.nemo_gym_datasets + dataset = load_nemo_gym_datasets(self._gym_dir, dataset_configs) + + # Build prompt → row lookup for multi-turn rollout_func + # (rollout_func only receives prompt text, needs to look up row data) + self._dataset_lookup = {} + for i in range(len(dataset)): + row = dataset[i] + # Use last message content as key (matches data_producer lookup) + prompt_text = row["prompt"][-1]["content"] + self._dataset_lookup[prompt_text] = row + + return TrainDatasetMeta( + train_dataset=dataset, + eval_dataset=None, + total_num_steps=0, # computed later by the builder + ) + + def get_training_args(self, cfg): + """Pass through vLLM settings and force async trainer for multi-turn.""" + args = {} + # Pass vLLM settings from vllm config block to TRL training args + if cfg.vllm: + vllm_cfg = cfg.vllm + max_len = getattr(vllm_cfg, "max_model_len", None) + gpu_util = getattr(vllm_cfg, "gpu_memory_utilization", None) + tp_size = getattr(vllm_cfg, "tensor_parallel_size", None) + if max_len: + args["vllm_max_model_length"] = max_len + if gpu_util: + args["vllm_gpu_memory_utilization"] = gpu_util + if tp_size: + args["vllm_tensor_parallel_size"] = tp_size + + # Force async trainer for multi-turn: NemoGymDataProducer needs the + # data producer protocol. Setting use_data_producer=True selects + # AxolotlAsyncGRPOTrainer which supports _create_data_producer(). + # With async_prefetch=False this runs synchronously — no threading. + if cfg.nemo_gym_multi_turn and self._agent_servers: + args["use_data_producer"] = True + LOG.info( + "NeMo Gym multi-turn: forcing use_data_producer=True for data producer protocol" + ) + + # Dataloader workers fork subprocesses that can't handle the async + # HTTP connections to NeMo Gym agents. Force num_workers=0. + if getattr(cfg, "dataloader_num_workers", None) not in (None, 0): + LOG.warning( + f"NeMo Gym: overriding dataloader_num_workers={cfg.dataloader_num_workers} → 0 " + "(forked workers can't use NeMo Gym agent connections)" + ) + cfg.dataloader_num_workers = 0 + + if args: + LOG.info(f"NeMo Gym plugin injecting training args: {args}") + return args if args else None + + def post_trainer_create(self, cfg, trainer): + """Wire NeMo Gym into the trainer (reward_fn or rollout_func).""" + if not cfg.nemo_gym_enabled: + return + + model_name = cfg.nemo_gym_model_name or cfg.base_model or "axolotl-model" + verify_timeout = cfg.nemo_gym_verify_timeout or 30 + multi_turn = cfg.nemo_gym_multi_turn or False + + # Handle weight sync. 
NeMo Gym skips NCCL init, so we need to either: + # - Install LoRA sync (when vllm_lora_sync=True) + # - Or no-op sync_weights (when using standard vLLM server) + trl_cfg = getattr(cfg, "trl", None) + if hasattr(trainer, "vllm_generation") and trainer.vllm_generation: + vllm_gen = trainer.vllm_generation + if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False): + self._setup_lora_sync(trainer) + # Verify the vLLM server supports runtime LoRA loading + self._check_lora_endpoint(vllm_gen) + else: + # No NCCL, no LoRA sync — skip all weight sync paths + vllm_gen.sync_weights = lambda: LOG.debug( + "Weight sync skipped (NeMo Gym mode)" + ) + type(vllm_gen).sync_weights = lambda self: LOG.debug( + "Weight sync skipped (NeMo Gym mode)" + ) + # Also patch the async trainer's internal sync method + if hasattr(trainer, "_maybe_sync_vllm_weights"): + trainer._maybe_sync_vllm_weights = lambda: LOG.debug( + "Async weight sync skipped (NeMo Gym mode)" + ) + LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)") + + if multi_turn: + self._wire_multi_turn(cfg, trainer, model_name, verify_timeout) + else: + self._wire_single_turn(trainer, model_name, verify_timeout) + + def _wire_single_turn(self, trainer, model_name, verify_timeout): + """Inject single-turn reward function into the trainer.""" + from .rewards import create_nemo_gym_reward_fn + + self._reward_fn = create_nemo_gym_reward_fn( + global_config=self._global_config, + verify_endpoints=self._verify_endpoints, + model_name=model_name, + verify_timeout=verify_timeout, + ) + + if hasattr(trainer, "reward_funcs"): + trainer.reward_funcs.append(self._reward_fn) + trainer.reward_func_names.append("nemo_gym") + trainer.reward_processing_classes.append(None) + LOG.info( + f"Added NeMo Gym reward function (single-turn). " + f"Total reward functions: {len(trainer.reward_funcs)}" + ) + else: + LOG.warning( + "Trainer does not have reward_funcs attribute. " + "NeMo Gym reward function not injected. " + "Ensure you are using a GRPO trainer." + ) + + def _wire_multi_turn(self, cfg, trainer, model_name, verify_timeout): + """Replace the data producer with NemoGymDataProducer. + + The plugin forces use_data_producer=True (in get_training_args) which + selects AxolotlAsyncGRPOTrainer. Here we swap its data_producer with + our NemoGymDataProducer that calls agent /run instead of vLLM generate. + """ + if not self._agent_servers: + LOG.warning( + "No NeMo Gym agent servers discovered. Multi-turn requires agent servers " + "started via ng_run with an agent config. Falling back to single-turn." + ) + self._wire_single_turn(trainer, model_name, verify_timeout) + return + + if not hasattr(trainer, "data_producer") or trainer.data_producer is None: + LOG.warning( + "Trainer has no data_producer. NeMo Gym multi-turn requires " + "use_data_producer=true (should be auto-set by plugin)." 
+ ) + return + + from axolotl.core.trainers.grpo.async_trainer import AsyncDataProducer + + from .data_producer import NemoGymDataProducer + + # Get the current producer's config and params + current = trainer.data_producer + # Unwrap AsyncDataProducer to get the inner producer's config + if isinstance(current, AsyncDataProducer): + inner = current._inner + else: + inner = current + + nemo_producer = NemoGymDataProducer( + config=inner.config, + prompt_dataset=inner._dataset, + num_generations=inner._num_generations, + generation_batch_size=inner._generation_batch_size, + train_batch_size=inner._train_batch_size, + steps_per_generation=inner._steps_per_generation, + shuffle_dataset=inner._shuffle_dataset, + seed=inner._seed, + agent_servers=self._agent_servers, + dataset_lookup=self._dataset_lookup or {}, + request_timeout=float(cfg.nemo_gym_run_timeout or 300), + ) + nemo_producer.set_trainer(trainer) + + # Re-wrap in AsyncDataProducer if async prefetch is enabled + if getattr(trainer.args, "async_prefetch", False): + nemo_producer = AsyncDataProducer( + nemo_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + + trainer.data_producer = nemo_producer + LOG.info( + f"NeMo Gym data producer installed " + f"(agent servers: {list(self._agent_servers.keys())}, " + f"async={'yes' if getattr(trainer.args, 'async_prefetch', False) else 'no'})" + ) + + # Passthrough reward function — agent /run already computed rewards + from .rewards import reward_env + + if hasattr(trainer, "reward_funcs"): + trainer.reward_funcs.append(reward_env) + trainer.reward_func_names.append("nemo_gym") + trainer.reward_processing_classes.append(None) + + @staticmethod + def _check_lora_endpoint(vllm_gen): + """Verify the vLLM server supports runtime LoRA loading.""" + import requests as http_requests + + if not hasattr(vllm_gen, "vllm_client") or vllm_gen.vllm_client is None: + return # Non-main rank in multi-GPU — client only exists on rank 0 + base_url = vllm_gen.vllm_client.base_url + try: + # Send a dummy load request — if the endpoint exists, we get a + # proper error (400/404 about the adapter), not a route 404. + resp = http_requests.post( + f"{base_url}/v1/load_lora_adapter", + json={"lora_name": "__probe__", "lora_path": "/nonexistent"}, + timeout=5, + ) + if ( + resp.status_code == 404 + and "Not Found" in resp.text + and "adapter" not in resp.text.lower() + ): + LOG.warning( + "vLLM server does not expose /v1/load_lora_adapter. " + "Set VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 when starting vLLM, e.g.:\n" + " VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 python -m vllm.entrypoints.openai.api_server " + "--enable-lora --max-lora-rank 64 ..." + ) + except Exception: + pass # Server might not be up yet, sync will warn later + + def _setup_lora_sync(self, trainer): + """Replace sync_weights with LoRA adapter sync via filesystem + HTTP. + + If the async trainer is detected (has ``_sync_lora_adapter``), delegates + to it — that method already handles multi-GPU (FSDP/DeepSpeed state_dict + gather, broadcast sync dir, barrier). + + Otherwise installs a standalone closure for the non-async GRPO path that + saves the adapter and POSTs to ``/v1/load_lora_adapter``. 
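+
+        Schematically, the standalone closure issues the request below
+        (a sketch; the ``load_inplace`` key mirrors the closure's payload
+        and may not be recognized by a stock vLLM server)::
+
+            POST {base_url}/v1/load_lora_adapter
+            {"lora_name": <base model>, "lora_path": <adapter dir>,
+             "load_inplace": true}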
+ """ + vllm_gen = trainer.vllm_generation + + # Async trainer path: delegate to its _sync_lora_adapter (multi-GPU safe) + if hasattr(trainer, "_sync_lora_adapter"): + + def lora_sync_weights(): + trainer._sync_lora_adapter() + + vllm_gen.sync_weights = lora_sync_weights + type(vllm_gen).sync_weights = lambda self: lora_sync_weights() + LOG.info( + "Installed LoRA adapter sync " + "(delegates to async trainer._sync_lora_adapter)" + ) + return + + # Non-async standard GRPO path: standalone closure + import os + import shutil + import tempfile + + import requests as http_requests + + base_model = getattr(trainer.args, "model_name_or_path", None) or "axolotl-lora" + sync_state = {"version": 0, "sync_dir": tempfile.mkdtemp(prefix="lora_sync_")} + + def lora_sync_weights(): + """Save LoRA adapter and load it into vLLM.""" + accelerator = vllm_gen.accelerator + model = vllm_gen.model + + if vllm_gen.mode != "server": + return + + sync_state["version"] += 1 + version = sync_state["version"] + adapter_path = os.path.join(sync_state["sync_dir"], f"v{version}") + + wrapped_model = getattr(trainer, "model_wrapped", model) + state_dict = accelerator.get_state_dict(wrapped_model) + + if accelerator.is_main_process: + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(adapter_path, state_dict=state_dict) + + base_url = vllm_gen.vllm_client.base_url + resp = http_requests.post( + f"{base_url}/v1/load_lora_adapter", + json={ + "lora_name": base_model, + "lora_path": adapter_path, + "load_inplace": True, + }, + timeout=30, + ) + if resp.status_code != 200: + resp = http_requests.post( + f"{base_url}/set_lora_adapter/", + json={ + "lora_name": "active_lora", + "lora_int_id": version, + "lora_path": adapter_path, + }, + timeout=30, + ) + if resp.status_code != 200: + LOG.warning( + f"Failed to set LoRA adapter: " + f"{resp.status_code} {resp.text}" + ) + return + + try: + vllm_gen.vllm_client.reset_prefix_cache() + except Exception as exc: + LOG.warning("Failed to reset prefix cache: %s", exc) + + if version > 1: + old = os.path.join(sync_state["sync_dir"], f"v{version - 1}") + if os.path.exists(old): + shutil.rmtree(old, ignore_errors=True) + + LOG.info(f"Synced LoRA adapter v{version} to vLLM ({adapter_path})") + + if accelerator.num_processes > 1: + import torch.distributed as dist + + if dist.is_initialized(): + dist.barrier() + + vllm_gen.sync_weights = lora_sync_weights + type(vllm_gen).sync_weights = lambda self: lora_sync_weights() + LOG.info("Installed LoRA adapter sync (standalone fallback)") + + def post_train_unload(self, cfg): + """Cleanup NeMo Gym servers if we started them.""" + if cfg.get("nemo_gym_enabled") and cfg.get("nemo_gym_auto_start", True): + from .server import _cleanup_servers + + _cleanup_servers() diff --git a/src/axolotl/integrations/nemo_gym/rewards.py b/src/axolotl/integrations/nemo_gym/rewards.py new file mode 100644 index 000000000..e99df0981 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/rewards.py @@ -0,0 +1,274 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +NeMo Gym reward functions. 
+
+Provides ready-to-use reward functions for axolotl configs::
+
+    trl:
+      reward_funcs:
+        # Multi-turn: passthrough reward from agent /run
+        - axolotl.integrations.nemo_gym.rewards.reward_env
+        # Single-turn: call /verify endpoints directly
+        - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
+"""
+
+from __future__ import annotations
+
+import threading
+from typing import Any
+
+import numpy as np
+import requests
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Multi-turn passthrough reward
+# ---------------------------------------------------------------------------
+
+
+def reward_env(completions, prompts=None, **kwargs):
+    """Passthrough: extract pre-computed reward from NeMo Gym agent /run response.
+
+    The ``NemoGymDataProducer`` injects ``env_reward`` into each sample's
+    kwargs after the agent returns from ``/run``. This function simply
+    forwards that value so TRL can log it alongside other reward signals.
+
+    Use this in your config when ``nemo_gym_multi_turn: true``::
+
+        trl:
+          reward_funcs:
+            - axolotl.integrations.nemo_gym.rewards.reward_env
+    """
+    env_rewards = kwargs.get("env_reward")
+    if env_rewards is not None:
+        if isinstance(env_rewards, (list, tuple)):
+            return [float(r) for r in env_rewards]
+        return [float(env_rewards) for _ in completions]
+    return [0.0 for _ in completions]
+
+
+# ---------------------------------------------------------------------------
+# Single-turn /verify reward
+# ---------------------------------------------------------------------------
+
+# Module-level cache for discovered verify URLs
+_verify_urls: dict[str, str] = {}
+_verify_urls_lock = threading.Lock()
+
+
+def _get_verify_urls(head_port: int = 11000) -> dict[str, str]:
+    """Discover verify endpoints from the NeMo Gym head server.
+
+    Results are cached so that the HTTP round-trip only happens once per
+    process. A lock guards against concurrent discovery from multiple
+    threads (e.g. async_prefetch background thread + main training thread).
+    """
+    if _verify_urls:
+        return _verify_urls
+
+    with _verify_urls_lock:
+        # Double-check after acquiring lock
+        if _verify_urls:
+            return _verify_urls
+
+        import yaml
+
+        try:
+            resp = requests.get(
+                f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
+            )
+            config = yaml.safe_load(resp.text)
+            if isinstance(config, str):
+                config = yaml.safe_load(config)
+            for _name, cfg in config.items():
+                if not isinstance(cfg, dict):
+                    continue
+                for srv_name, srv_cfg in cfg.get("resources_servers", {}).items():
+                    if (
+                        isinstance(srv_cfg, dict)
+                        and "host" in srv_cfg
+                        and "port" in srv_cfg
+                    ):
+                        _verify_urls[srv_name] = (
+                            f"http://{srv_cfg['host']}:{srv_cfg['port']}/verify"
+                        )
+        except Exception as exc:
+            LOG.warning(f"Failed to discover NeMo Gym verify endpoints: {exc}")
+
+    return _verify_urls
+
+
+def reward_nemo_gym_verify(completions, prompts=None, **kwargs):
+    """Call NeMo Gym ``/verify`` endpoint for each completion (single-turn).
+
+    Requires ``resources_server_ref`` and ``verify_extra`` kwargs, which the
+    NeMo Gym dataset loader injects automatically.
+ + Use this in your config when ``nemo_gym_multi_turn: false``:: + + trl: + reward_funcs: + - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify + """ + verify_urls = _get_verify_urls() + refs = kwargs.get("resources_server_ref", []) + extras = kwargs.get("verify_extra", []) + scores = [] + + for i, completion in enumerate(completions): + text = completion[0]["content"] if completion else "" + prompt = prompts[i][0]["content"] if prompts and i < len(prompts) else "" + srv_name = ( + refs[i]["name"] if i < len(refs) and isinstance(refs[i], dict) else "" + ) + url = verify_urls.get(srv_name, "") + + if not url: + scores.append(0.0) + continue + + extra = extras[i] if i < len(extras) else {} + req = {k: v for k, v in extra.items() if v is not None} + req["responses_create_params"] = { + "input": [{"role": "user", "content": prompt}] + } + req["response"] = { + "id": "resp", + "created_at": 0, + "model": "axolotl", + "object": "response", + "output": [ + { + "id": "msg", + "role": "assistant", + "type": "message", + "status": "completed", + "content": [ + {"type": "output_text", "text": text, "annotations": []} + ], + } + ], + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + } + + try: + resp = requests.post(url, json=req, timeout=30) + reward = resp.json().get("reward", 0.0) if resp.ok else 0.0 + except Exception as exc: + LOG.warning(f"Verify request to {url} failed: {exc}") + reward = 0.0 + + scores.append(float(reward)) + + return scores + + +# --------------------------------------------------------------------------- +# Factory used internally by the plugin +# --------------------------------------------------------------------------- + + +def create_nemo_gym_reward_fn( + global_config: dict, + verify_endpoints: dict[str, str], + model_name: str = "axolotl-model", + verify_timeout: int = 30, +): + """Create a reward function bound to specific verify endpoints. + + Used internally by ``NemoGymPlugin._wire_single_turn()`` to inject a + reward function that already knows the endpoint map (no discovery needed). 
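+
+    A minimal standalone sketch (the server name and endpoint URL here are
+    illustrative)::
+
+        fn = create_nemo_gym_reward_fn(
+            global_config=config,
+            verify_endpoints={"reasoning_gym": "http://127.0.0.1:11001/verify"},
+        )
+        scores = fn(
+            completions,
+            prompts=prompts,
+            resources_server_ref=refs,
+            verify_extra=extras,
+        )  # -> np.ndarray of per-completion rewards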
+ """ + + def reward_fn( + completions: list[list[dict[str, str]]], + prompts: list[list[dict[str, str]]] | None = None, + **kwargs: Any, + ) -> np.ndarray: + resources_server_refs = kwargs.get("resources_server_ref", []) + verify_extras = kwargs.get("verify_extra", []) + + scores = [] + for i, completion in enumerate(completions): + completion_text = completion[0]["content"] + task_prompt = prompts[i][0]["content"] if prompts else "" + + server_name = ( + resources_server_refs[i]["name"] + if i < len(resources_server_refs) + else None + ) + + if server_name is None or server_name not in verify_endpoints: + LOG.warning( + f"No verify endpoint for server '{server_name}', returning 0 reward" + ) + scores.append(0.0) + continue + + verify_endpoint = verify_endpoints[server_name] + + verify_request = ( + {k: v for k, v in verify_extras[i].items() if v is not None} + if i < len(verify_extras) + else {} + ) + verify_request["responses_create_params"] = { + "input": [{"role": "user", "content": task_prompt}] + } + verify_request["response"] = { + "id": "resp", + "created_at": 0, + "model": model_name, + "object": "response", + "output": [ + { + "id": "msg", + "role": "assistant", + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "text": completion_text, + "annotations": [], + } + ], + } + ], + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + } + + try: + resp = requests.post( + verify_endpoint, json=verify_request, timeout=verify_timeout + ) + if resp.status_code == 200: + reward = resp.json().get("reward", 0.0) + else: + LOG.warning( + f"Verify request returned status {resp.status_code}: {resp.text[:200]}" + ) + reward = 0.0 + except requests.exceptions.RequestException as exc: + LOG.warning(f"Verify request failed: {exc}") + reward = 0.0 + + scores.append(float(reward)) + + return np.array(scores) + + return reward_fn diff --git a/src/axolotl/integrations/nemo_gym/server.py b/src/axolotl/integrations/nemo_gym/server.py new file mode 100644 index 000000000..0af9b3b71 --- /dev/null +++ b/src/axolotl/integrations/nemo_gym/server.py @@ -0,0 +1,242 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. + +""" +NeMo Gym server lifecycle management. + +Handles cloning the NeMo Gym repo, starting resource servers, +waiting for readiness, and cleanup on exit. +""" + +from __future__ import annotations + +import atexit +import os +import subprocess # nosec B404 +import time + +import requests +import yaml + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +_ng_process = None +_ng_log_file = None + + +def ensure_gym_repo(gym_dir: str, auto_clone: bool = True) -> str: + """Clone the NeMo Gym repo if it doesn't exist. + + Args: + gym_dir: Path to the NeMo Gym directory. + auto_clone: Whether to auto-clone if missing. + + Returns: + Resolved path to the NeMo Gym directory. + """ + gym_dir = os.path.expanduser(gym_dir) + if os.path.exists(gym_dir): + LOG.info(f"NeMo Gym directory exists at {gym_dir}") + return gym_dir + + if not auto_clone: + raise FileNotFoundError( + f"NeMo Gym directory not found at {gym_dir} and auto_clone is disabled." 
+ ) + + LOG.info(f"Cloning NeMo Gym to {gym_dir}...") + subprocess.run( # nosec + ["git", "clone", "https://github.com/NVIDIA-NeMo/Gym.git", gym_dir], + check=True, + ) + return gym_dir + + +def ensure_gym_venv(gym_dir: str): + """Set up the NeMo Gym Python venv if not present.""" + venv_python = os.path.join(gym_dir, ".venv", "bin", "python") + if os.path.exists(venv_python): + return + + LOG.info("Setting up NeMo Gym venv...") + subprocess.run(["uv", "venv", "--python", "3.12"], cwd=gym_dir, check=True) # nosec + subprocess.run( # nosec + ["bash", "-c", "source .venv/bin/activate && uv sync"], + cwd=gym_dir, + check=True, + ) + + +def start_servers( + gym_dir: str, + config_paths: list[str], + head_port: int = 11000, + timeout: int = 360, +): + """Start NeMo Gym resource servers via ng_run. + + Args: + gym_dir: Path to the NeMo Gym directory. + config_paths: List of config YAML paths relative to gym_dir. + head_port: Port for the head server. + timeout: Max seconds to wait for servers. + """ + global _ng_process, _ng_log_file + + head_url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml" + + # Check if already running + try: + requests.get(head_url, timeout=2) + LOG.info(f"NeMo Gym servers already running on port {head_port}.") + return + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + pass + + ng_run_bin = os.path.join(gym_dir, ".venv", "bin", "ng_run") + config_arg = f"+config_paths=[{','.join(config_paths)}]" + _ng_log_file = open(os.path.join(gym_dir, "ng_run.log"), "w") # noqa: SIM115 + _ng_process = subprocess.Popen( # nosec B603 + [ng_run_bin, config_arg, "+skip_venv_if_present=true"], + cwd=gym_dir, + stdout=_ng_log_file, + stderr=subprocess.STDOUT, + ) + + atexit.register(_cleanup_servers) + + LOG.info("Waiting for NeMo Gym head server...") + for _ in range(timeout // 3): + try: + requests.get(head_url, timeout=2) + LOG.info("NeMo Gym head server is ready.") + return + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + if _ng_process.poll() is not None: + raise RuntimeError( + "NeMo Gym server process exited unexpectedly. " + f"Check {gym_dir}/ng_run.log for details." + ) from None + time.sleep(3) + + raise RuntimeError( + f"NeMo Gym servers did not start within {timeout}s. " + f"Check {gym_dir}/ng_run.log for details." + ) + + +def get_server_configs(head_port: int = 11000) -> dict: + """Fetch the global config from the NeMo Gym head server. + + Returns: + Dict mapping server_name -> server config. + """ + response = requests.get( + f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5 + ) + response.raise_for_status() + result = yaml.safe_load(response.text) + # NeMo Gym head server double-encodes: YAML string inside a YAML string + if isinstance(result, str): + result = yaml.safe_load(result) + return result + + +def get_agent_servers( + global_config: dict, head_host: str = "127.0.0.1" +) -> dict[str, str]: + """Discover NeMo Gym agent servers from the global config. + + Agent servers handle multi-turn orchestration via /run endpoint. + Returns mapping of agent_name → URL (e.g., {"simple_agent": "http://host:port"}). 
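+
+    The slice of the global config this reads looks roughly like (a sketch;
+    the top-level name and port are illustrative)::
+
+        wordle_simple_agent:
+          responses_api_agents:
+            simple_agent:
+              host: 0.0.0.0
+              port: 12000
+
+    which would yield ``{"wordle_simple_agent": "http://<head_host>:12000"}``.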
+    """
+    agents = {}
+    for top_name, top_cfg in global_config.items():
+        if not isinstance(top_cfg, dict):
+            continue
+        agent_dict = top_cfg.get("responses_api_agents", {})
+        if not agent_dict:
+            continue
+        for _agent_name, agent_cfg in agent_dict.items():
+            if not isinstance(agent_cfg, dict):
+                continue
+            host = agent_cfg.get("host", "127.0.0.1")
+            port = agent_cfg.get("port")
+            if not port:
+                continue
+            # Replace loopback with head_host for remote access
+            host = _normalize_host(host, fallback=head_host)
+            # Use the top-level config name (not the inner agent name)
+            # because dataset agent_ref.name references the top-level name
+            agents[top_name] = f"http://{host}:{port}"
+    if agents:
+        LOG.info(f"Discovered NeMo Gym agent servers: {agents}")
+    return agents
+
+
+def _normalize_host(host: str, fallback: str = "127.0.0.1") -> str:
+    """Normalize bind-all and loopback addresses for reachability."""
+    if host in ("0.0.0.0", "localhost"):  # nosec B104
+        return fallback
+    return host
+
+
+def get_server_base_url(global_config: dict, server_name: str) -> str:
+    """Get the base URL for a given resource server."""
+    try:
+        srv_cfg = global_config[server_name]["resources_servers"][server_name]
+        host = _normalize_host(srv_cfg["host"])
+        return f"http://{host}:{srv_cfg['port']}"
+    except (KeyError, TypeError) as exc:
+        raise ValueError(
+            f"Could not find resource server config for '{server_name}' in NeMo Gym. "
+            f"Available servers: {list(global_config.keys())}"
+        ) from exc
+
+
+def get_verify_endpoint(global_config: dict, server_name: str) -> str:
+    """Get the /verify endpoint URL for a given resource server."""
+    return f"{get_server_base_url(global_config, server_name)}/verify"
+
+
+def wait_for_resource_servers(global_config: dict, timeout: int = 180):
+    """Wait for all resource servers in the config to become reachable."""
+    for srv_name in global_config:
+        try:
+            srv_cfg = global_config[srv_name]["resources_servers"][srv_name]
+        except (KeyError, TypeError):
+            continue  # Skip non-server config entries silently
+
+        host, port = _normalize_host(srv_cfg["host"]), srv_cfg["port"]
+        LOG.info(f"Waiting for resource server '{srv_name}' at {host}:{port}...")
+        for _ in range(timeout // 2):
+            try:
+                requests.get(f"http://{host}:{port}/", timeout=2)
+                LOG.info(f"Resource server '{srv_name}' is ready.")
+                break
+            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+                # Retry on transient connect errors and slow startups alike
+                time.sleep(2)
+        else:
+            raise RuntimeError(
+                f"Resource server '{srv_name}' at {host}:{port} "
+                f"did not start within {timeout}s."
+ ) + + +def _cleanup_servers(): + """Terminate NeMo Gym server process on exit.""" + global _ng_process, _ng_log_file + if _ng_process is not None and _ng_process.poll() is None: + LOG.info("Terminating NeMo Gym servers...") + _ng_process.terminate() + try: + _ng_process.wait(timeout=10) + except subprocess.TimeoutExpired: + _ng_process.kill() + if _ng_log_file is not None: + _ng_log_file.close() diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index f4fcfa190..e292d89f8 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -446,6 +446,114 @@ def main(script_args: ScriptArguments): "logprob_token_ids": extract_logprobs(all_outputs)[1], } + # --- OpenAI-compatible endpoints (for NeMo Gym agent integration) --- + + @app.get("/v1/models") + async def list_models(): + """OpenAI-compatible models endpoint.""" + return { + "object": "list", + "data": [ + {"id": script_args.model, "object": "model", "owned_by": "axolotl"} + ], + } + + @app.post("/v1/chat/completions") + async def openai_chat_completions(request_body: dict): + """OpenAI-compatible chat completions endpoint. + + Translates OpenAI format to our internal /chat/ format so NeMo Gym's + model server proxy can call us directly. + """ + messages_list = request_body.get("messages", []) + temperature = request_body.get("temperature", 1.0) + max_tokens = request_body.get("max_tokens", 512) + top_p = request_body.get("top_p", 1.0) + n = request_body.get("n", 1) + + generation_kwargs = { + "n": n, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "logprobs": 0, # Always return logprobs (NeMo Gym needs them) + } + sampling_params = SamplingParams( + **{k: v for k, v in generation_kwargs.items() if v is not None} + ) + + # Send to vLLM worker + chunked = chunk_list([messages_list], script_args.data_parallel_size) + for conn, chunk in zip(connections, chunked, strict=True): + if not chunk: + chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": chunk, + "sampling_params": sampling_params, + "use_tqdm": False, + "lora_request": active_lora["request"], + } + conn.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + all_outputs = [conn.recv() for conn in connections] + all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] + all_outputs = list(chain.from_iterable(all_outputs)) + + if not all_outputs: + return {"choices": [], "model": script_args.model} + + # Format as OpenAI response + import uuid + + choices = [] + for i, output in enumerate(all_outputs): + for j, out in enumerate(output.outputs): + text = out.text + # Extract token IDs if requested + # Build logprobs in OpenAI format + lp_list = None + if out.logprobs: + lp_list = { + "content": [ + {"token": "", "logprob": next(iter(lp.values())).logprob} # nosec B105 + for lp in out.logprobs + ] + } + + choice = { + "index": i * n + j, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop" + if out.finish_reason == "stop" + else "length", + "logprobs": lp_list, + } + # Include token ID information for NeMo Gym + choice["prompt_token_ids"] = output.prompt_token_ids + choice["generation_token_ids"] = list(out.token_ids) + if out.logprobs: + choice["generation_log_probs"] = [ + next(iter(lp.values())).logprob for lp in out.logprobs + ] + choices.append(choice) + + prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0 + completion_tokens = sum( + len(out.token_ids) for o in all_outputs for out 
in o.outputs + ) + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "model": script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + # --- Weight sync endpoints (legacy fallback, same as TRL) --- @app.post("/init_communicator/") diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/e2e/solo/test_trainer_loss_calc.py similarity index 100% rename from tests/monkeypatch/test_trainer_loss_calc.py rename to tests/e2e/solo/test_trainer_loss_calc.py diff --git a/tests/integrations/test_nemo_gym.py b/tests/integrations/test_nemo_gym.py new file mode 100644 index 000000000..7fd53cee0 --- /dev/null +++ b/tests/integrations/test_nemo_gym.py @@ -0,0 +1,621 @@ +"""Unit tests for NeMo Gym integration. + +Tests the core parsing, routing, reward, and plugin wiring logic +without requiring a running NeMo Gym server or GPU. +""" + +import unittest +from unittest.mock import MagicMock, patch + + +class TestParseAgentResponse(unittest.TestCase): + """Tests for _parse_agent_response in multi_turn.py.""" + + def _parse(self, response, eos_token_id=2): + from axolotl.integrations.nemo_gym.multi_turn import _parse_agent_response + + return _parse_agent_response(response, eos_token_id) + + def test_empty_response_returns_defaults(self): + result = self._parse({}) + assert result["prompt_ids"] == [2] + assert result["completion_ids"] == [2] + assert result["env_mask"] == [0] + assert result["reward"] == 0.0 + assert result["num_turns"] == 0 + + def test_error_response_returns_defaults(self): + result = self._parse({"error": "something broke"}) + assert result["reward"] == 0.0 + assert result["num_turns"] == 0 + + def test_single_turn_function_call(self): + response = { + "response": { + "output": [ + { + "type": "function_call", + "name": "guess_word", + "arguments": '{"guess": "crane"}', + "call_id": "call_1", + "prompt_token_ids": [10, 20, 30], + "generation_token_ids": [40, 50], + "generation_log_probs": [-0.1, -0.2], + } + ] + }, + "reward": 0.5, + } + result = self._parse(response) + assert result["prompt_ids"] == [10, 20, 30] + assert result["completion_ids"] == [40, 50] + assert result["env_mask"] == [1, 1] # model tokens + assert result["logprobs"] == [-0.1, -0.2] + assert result["reward"] == 0.5 + assert result["num_turns"] == 1 + + def test_multi_turn_preserves_env_mask(self): + """Second turn's prompt tokens (tool results) get env_mask=0.""" + response = { + "response": { + "output": [ + { + "type": "function_call", + "prompt_token_ids": [10, 20], + "generation_token_ids": [30, 31], + "generation_log_probs": [-0.1, -0.2], + }, + { + "type": "function_call_output", + "output": '{"feedback": "XYGXY"}', + }, + { + "type": "function_call", + # prompt includes original + gen + tool output + "prompt_token_ids": [10, 20, 30, 31, 100, 101, 102], + "generation_token_ids": [40, 41], + "generation_log_probs": [-0.3, -0.4], + }, + ] + }, + "reward": 0.3, + } + result = self._parse(response) + assert result["prompt_ids"] == [10, 20] + # completion = gen1 + tool_result + gen2 + assert result["completion_ids"] == [30, 31, 100, 101, 102, 40, 41] + # env_mask: gen1=model(1), tool=env(0), gen2=model(1) + assert result["env_mask"] == [1, 1, 0, 0, 0, 1, 1] + assert result["num_turns"] == 2 + + def test_empty_output_preserves_reward(self): + response = { + "response": {"output": []}, + "reward": 0.42, + } + result = self._parse(response) + 
assert result["reward"] == 0.42 + + def test_message_only_output(self): + """A message with text but no function calls.""" + response = { + "response": { + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "I'll guess crane."} + ], + "prompt_token_ids": [10, 20], + "generation_token_ids": [30, 31, 32], + "generation_log_probs": [-0.1, -0.2, -0.3], + } + ] + }, + "reward": 0.1, + } + result = self._parse(response) + assert result["num_turns"] == 1 + assert result["completion_ids"] == [30, 31, 32] + assert result["env_mask"] == [1, 1, 1] + + +class TestRewardEnv(unittest.TestCase): + """Tests for reward_env passthrough function.""" + + def test_with_list_rewards(self): + from axolotl.integrations.nemo_gym.rewards import reward_env + + result = reward_env([["comp1"], ["comp2"]], env_reward=[0.5, 0.8]) + assert result == [0.5, 0.8] + + def test_with_scalar_reward(self): + from axolotl.integrations.nemo_gym.rewards import reward_env + + result = reward_env([["comp1"], ["comp2"]], env_reward=0.7) + assert result == [0.7, 0.7] + + def test_missing_reward_returns_zeros(self): + from axolotl.integrations.nemo_gym.rewards import reward_env + + result = reward_env([["comp1"], ["comp2"]]) + assert result == [0.0, 0.0] + + +class TestRewardNemoGymVerify(unittest.TestCase): + """Tests for reward_nemo_gym_verify with mocked HTTP.""" + + @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls") + @patch("axolotl.integrations.nemo_gym.rewards.requests") + def test_calls_verify_endpoint(self, mock_requests, mock_get_urls): + from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify + + mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"} + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.json.return_value = {"reward": 0.75} + mock_requests.post.return_value = mock_resp + + result = reward_nemo_gym_verify( + completions=[[{"role": "assistant", "content": "crane"}]], + prompts=[[{"role": "user", "content": "Guess a word"}]], + resources_server_ref=[{"name": "wordle"}], + verify_extra=[{}], + ) + + assert result == [0.75] + mock_requests.post.assert_called_once() + + @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls") + def test_missing_server_returns_zero(self, mock_get_urls): + from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify + + mock_get_urls.return_value = {} + + result = reward_nemo_gym_verify( + completions=[[{"role": "assistant", "content": "crane"}]], + prompts=[[{"role": "user", "content": "Guess"}]], + resources_server_ref=[{"name": "unknown_server"}], + verify_extra=[{}], + ) + assert result == [0.0] + + +class TestNormalizeHost(unittest.TestCase): + """Tests for server.py _normalize_host helper.""" + + def test_zero_addr_normalized(self): + from axolotl.integrations.nemo_gym.server import _normalize_host + + assert _normalize_host("0.0.0.0") == "127.0.0.1" + + def test_localhost_normalized(self): + from axolotl.integrations.nemo_gym.server import _normalize_host + + assert _normalize_host("localhost") == "127.0.0.1" + + def test_loopback_passthrough(self): + from axolotl.integrations.nemo_gym.server import _normalize_host + + assert _normalize_host("127.0.0.1") == "127.0.0.1" + + def test_custom_fallback(self): + from axolotl.integrations.nemo_gym.server import _normalize_host + + assert _normalize_host("0.0.0.0", fallback="10.0.0.1") == "10.0.0.1" + + def test_real_ip_passthrough(self): + from axolotl.integrations.nemo_gym.server import _normalize_host + 
+ assert _normalize_host("192.168.1.50") == "192.168.1.50" + + +class TestDatasetLookupKeying(unittest.TestCase): + """Verify dataset lookup uses last message content as key.""" + + def test_single_message_prompt(self): + """Single-message prompt: [0] == [-1], both work.""" + prompt = [{"role": "user", "content": "Play Wordle!"}] + assert prompt[0]["content"] == prompt[-1]["content"] + + def test_multi_message_prompt_uses_last(self): + """Multi-message prompt: must use [-1] to match data_producer lookup.""" + prompt = [ + {"role": "system", "content": "You are a game player."}, + {"role": "user", "content": "Play Wordle!"}, + ] + # data_producer.py line 92 uses prompt[-1] + key = prompt[-1]["content"] + assert key == "Play Wordle!" + # Old code used prompt[0] which would be wrong here + assert prompt[0]["content"] != key + + +class TestAgentRefPreservation(unittest.TestCase): + """Verify agent_ref is preserved through the dispatch chain.""" + + def test_data_producer_preserves_agent_ref(self): + """Simulates the data_producer lookup logic.""" + # Simulate what plugin.py builds + dataset_lookup = { + "Play Wordle!": { + "prompt": [{"role": "user", "content": "Play Wordle!"}], + "agent_ref": {"name": "wordle_simple_agent"}, + "verify_extra": { + "responses_create_params": { + "input": [{"role": "user", "content": "Play Wordle!"}] + } + }, + } + } + + # Simulate data_producer.py logic (after fix) + prompt_text = "Play Wordle!" + full_item = dataset_lookup.get(prompt_text, {}) + item = full_item.get("verify_extra", {}) + if "agent_ref" in full_item and "agent_ref" not in item: + item["agent_ref"] = full_item["agent_ref"] + + assert "agent_ref" in item + assert item["agent_ref"]["name"] == "wordle_simple_agent" + + def test_multi_turn_preserves_agent_ref(self): + """Simulates the multi_turn.py dispatch logic.""" + dataset_lookup = { + "Play Wordle!": { + "agent_ref": {"name": "wordle_simple_agent"}, + "verify_extra": { + "responses_create_params": { + "input": [{"role": "user", "content": "Play Wordle!"}] + } + }, + } + } + + # Simulate multi_turn.py logic (after fix) + prompt_str = "Play Wordle!" 
+ full_item = None + for key, val in dataset_lookup.items(): + if isinstance(key, str) and prompt_str == key: + full_item = val + break + + dispatched = full_item.get("verify_extra", full_item) + if isinstance(dispatched, dict) and "agent_ref" not in dispatched: + agent_ref = full_item.get("agent_ref") + if agent_ref: + dispatched = {**dispatched, "agent_ref": agent_ref} + + assert "agent_ref" in dispatched + assert dispatched["agent_ref"]["name"] == "wordle_simple_agent" + + +class TestCallAgentsRouting(unittest.TestCase): + """Tests for _call_agents routing via agent_ref.""" + + def test_routes_to_correct_agent(self): + """Items with agent_ref should route to the matching agent server.""" + + agent_servers = { + "wordle_agent": "http://localhost:11111", + "math_agent": "http://localhost:22222", + } + + items = [ + { + "agent_ref": {"name": "wordle_agent"}, + "responses_create_params": { + "input": [{"role": "user", "content": "Play"}] + }, + } + ] + + # We can't actually call the agent, but verify the URL resolution + # by checking _call_agents builds the right request + # The function uses aiohttp — just verify agent_ref lookup works + item = items[0] + agent_ref = item.get("agent_ref", {}) + agent_name = agent_ref.get("name", "") + agent_url = agent_servers.get(agent_name, "") + assert agent_url == "http://localhost:11111" + + def test_fallback_to_first_agent(self): + """Items without agent_ref should use first available agent.""" + agent_servers = {"default_agent": "http://localhost:33333"} + item = { + "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]} + } + agent_ref = item.get("agent_ref", {}) + agent_name = agent_ref.get("name", "") + agent_url = agent_servers.get(agent_name, "") + if not agent_url and agent_servers: + agent_url = next(iter(agent_servers.values())) + assert agent_url == "http://localhost:33333" + + +class TestPluginDefaults(unittest.TestCase): + """Tests for plugin config enforcement.""" + + def test_dataloader_num_workers_forced_to_zero(self): + """Plugin should set dataloader_num_workers=0 for NeMo Gym.""" + + # Simulate the plugin logic + class FakeCfg: + dataloader_num_workers = 4 + nemo_gym_multi_turn = True + + cfg = FakeCfg() + # Replicate plugin.get_training_args logic + if getattr(cfg, "dataloader_num_workers", None) not in (None, 0): + pass # would log warning + cfg.dataloader_num_workers = 0 + assert cfg.dataloader_num_workers == 0 + + def test_dataloader_num_workers_none_stays_zero(self): + class FakeCfg: + dataloader_num_workers = None + + cfg = FakeCfg() + cfg.dataloader_num_workers = 0 + assert cfg.dataloader_num_workers == 0 + + +class TestNemoGymE2E(unittest.TestCase): + """End-to-end test: data producer → agent (mocked) → parse → tensors → rewards. + + Exercises the full NemoGymDataProducer.produce() pipeline with mocked HTTP + responses, verifying that multi-turn Wordle agent responses are correctly + parsed into padded tensors with proper env_mask, logprobs, and rewards. + No GPU or NeMo Gym server required. 
+ """ + + # A realistic 2-turn agent /run response (guess + feedback + guess + done) + AGENT_RESPONSE = { + "response": { + "output": [ + { + "type": "function_call", + "name": "guess_word", + "arguments": '{"guess": "crane"}', + "call_id": "call_1", + "id": "call_1", + "status": "completed", + "prompt_token_ids": [1, 2, 3, 4, 5], + "generation_token_ids": [10, 11, 12, 13], + "generation_log_probs": [-0.1, -0.2, -0.3, -0.4], + }, + { + "type": "function_call_output", + "call_id": "call_1", + "output": '{"feedback":"XYGXY","guesses_remaining":5,"done":false}', + }, + { + "type": "function_call", + "name": "guess_word", + "arguments": '{"guess": "slide"}', + "call_id": "call_2", + "id": "call_2", + "status": "completed", + # prompt = original(5) + gen1(4) + tool_output(3 tokens) + "prompt_token_ids": [1, 2, 3, 4, 5, 10, 11, 12, 13, 50, 51, 52], + "generation_token_ids": [20, 21, 22], + "generation_log_probs": [-0.5, -0.6, -0.7], + }, + ], + }, + "reward": 0.42, + } + + def _make_mock_trainer(self): + """Create a minimal mock trainer with the attributes produce() needs.""" + trainer = MagicMock() + trainer.accelerator.is_main_process = True + trainer.accelerator.device = "cpu" + trainer.max_completion_length = 512 + trainer.temperature = 0.8 + trainer.pad_token_id = 0 + trainer.processing_class.eos_token_id = 2 + trainer.processing_class.batch_decode.return_value = ["crane slide"] + return trainer + + @patch("axolotl.integrations.nemo_gym.data_producer._call_agents") + def test_produce_returns_valid_rollout_dataset(self, mock_call_agents): + """Full pipeline: produce() → _call_agents (mocked) → parse → RolloutDataset.""" + + from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer + + # Mock _call_agents — it's async, so return a coroutine + async def fake_call_agents(**kwargs): + return [self.AGENT_RESPONSE, self.AGENT_RESPONSE] + + mock_call_agents.side_effect = fake_call_agents + + # Build a minimal mock of GRPODataProducer's __init__ dependencies + # We can't easily call super().__init__, so we'll set attributes directly + producer = NemoGymDataProducer.__new__(NemoGymDataProducer) + producer._agent_servers = {"wordle_agent": "http://mock:9999"} + producer._dataset_lookup = { + "Play Wordle!": { + "agent_ref": {"name": "wordle_agent"}, + "verify_extra": { + "responses_create_params": { + "input": [{"role": "user", "content": "Play Wordle!"}], + } + }, + } + } + producer._request_timeout = 30 + producer._num_generations = 2 + + # Mock the trainer + trainer = self._make_mock_trainer() + producer._trainer = trainer + + # Mock the prompt iterator (returns a batch of 1 input) + producer._prompt_iter = iter( + [ + [ + { + "prompt": [{"role": "user", "content": "Play Wordle!"}], + } + ] + ] + ) + producer._prompt_dl = [ + [{"prompt": [{"role": "user", "content": "Play Wordle!"}]}] + ] + + # Call produce + result = producer.produce(model=MagicMock(), global_step=1) + + # Verify result structure + assert result is not None + data = result._data + + # Check tensor shapes — 2 rollouts (num_generations=2) + assert data["prompt_ids"].shape[0] == 2 + assert data["completion_ids"].shape[0] == 2 + assert data["completion_mask"].shape[0] == 2 + assert data["sampling_per_token_logps"].shape[0] == 2 + assert data["tool_mask"].shape[0] == 2 + + # Verify completion content — each rollout should have: + # gen1(4) + tool_output(3) + gen2(3) = 10 tokens + # (padded to same length across the batch, but both are same here) + comp_len = data["completion_mask"][0].sum().item() + assert comp_len == 
10, f"Expected 10 completion tokens, got {comp_len}" + + # Verify env_mask: gen1=1,1,1,1 tool=0,0,0 gen2=1,1,1 + tool_mask = data["tool_mask"][0][:comp_len].tolist() + assert tool_mask == [1, 1, 1, 1, 0, 0, 0, 1, 1, 1] + + # Verify logprobs are populated (use approx for float32 precision) + import pytest + + logps = data["sampling_per_token_logps"][0][:comp_len].tolist() + assert logps[:4] == pytest.approx([-0.1, -0.2, -0.3, -0.4], abs=1e-6) + assert logps[4:7] == pytest.approx([0.0, 0.0, 0.0], abs=1e-6) + assert logps[7:10] == pytest.approx([-0.5, -0.6, -0.7], abs=1e-6) + + # Verify rewards were injected into inputs + assert data["_deferred_inputs"][0]["env_reward"] == 0.42 + assert data["_deferred_inputs"][1]["env_reward"] == 0.42 + + # Verify deferred scoring markers + assert data["_pending_policy_logps"] is True + + @patch("axolotl.integrations.nemo_gym.data_producer._call_agents") + def test_produce_handles_failed_agent_response(self, mock_call_agents): + """Failed agent responses should produce default (length-1) rollouts.""" + + from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer + + # One success, one failure — async mock + async def fake_call_agents(**kwargs): + return [ + self.AGENT_RESPONSE, + { + "error": "Connection timeout", + "response": {"output": []}, + "reward": 0.0, + }, + ] + + mock_call_agents.side_effect = fake_call_agents + + producer = NemoGymDataProducer.__new__(NemoGymDataProducer) + producer._agent_servers = {"wordle_agent": "http://mock:9999"} + producer._dataset_lookup = {} + producer._request_timeout = 30 + producer._num_generations = 2 + producer._trainer = self._make_mock_trainer() + producer._prompt_iter = iter( + [[{"prompt": [{"role": "user", "content": "Play!"}]}]] + ) + producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]] + + result = producer.produce(model=MagicMock(), global_step=1) + + assert result is not None + data = result._data + + # Both rollouts present + assert data["completion_ids"].shape[0] == 2 + + # First rollout has real tokens, second has just eos (length 1) + mask0 = data["completion_mask"][0].sum().item() + mask1 = data["completion_mask"][1].sum().item() + assert mask0 == 10 # full response + assert mask1 == 1 # default fallback (just eos) + + # Rewards: success=0.42, failure=0.0 + assert data["_deferred_inputs"][0]["env_reward"] == 0.42 + assert data["_deferred_inputs"][1]["env_reward"] == 0.0 + + @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls") + @patch("axolotl.integrations.nemo_gym.rewards.requests") + def test_reward_functions_chain(self, mock_requests, mock_get_urls): + """Test that reward_env and reward_nemo_gym_verify can be used together.""" + from axolotl.integrations.nemo_gym.rewards import ( + reward_env, + reward_nemo_gym_verify, + ) + + completions = [[{"role": "assistant", "content": "crane"}]] + prompts = [[{"role": "user", "content": "Guess"}]] + + # reward_env: passthrough from agent + env_result = reward_env(completions, prompts, env_reward=[0.42]) + assert env_result == [0.42] + + # reward_nemo_gym_verify: calls /verify + mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"} + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.json.return_value = {"reward": 0.75} + mock_requests.post.return_value = mock_resp + + verify_result = reward_nemo_gym_verify( + completions, + prompts, + resources_server_ref=[{"name": "wordle"}], + verify_extra=[{}], + ) + assert verify_result == [0.75] + + # Both rewards can coexist (as they would in a 
multi-reward config) + combined = [e + v for e, v in zip(env_result, verify_result, strict=True)] + assert combined == [1.17] + + +class TestLoRASyncSetup(unittest.TestCase): + """Tests for _setup_lora_sync delegation logic.""" + + def test_delegates_to_async_trainer(self): + """When trainer has _sync_lora_adapter, the closure should delegate.""" + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin.__new__(NemoGymPlugin) + + trainer = MagicMock() + trainer._sync_lora_adapter = MagicMock() + trainer.vllm_generation = MagicMock() + + plugin._setup_lora_sync(trainer) + + # The closure should be installed + trainer.vllm_generation.sync_weights() + trainer._sync_lora_adapter.assert_called_once() + + def test_check_lora_endpoint_skips_non_main_rank(self): + """_check_lora_endpoint should not crash when vllm_client is absent (rank 1).""" + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + vllm_gen = MagicMock(spec=[]) # No attributes at all + # Should not raise + NemoGymPlugin._check_lora_endpoint(vllm_gen) + + +if __name__ == "__main__": + unittest.main()