Nemo gym integration (#3516) [skip ci]

* nemo gym integration with grpo wip

* mostly working

* cleanup

* simplify

* update docs

* nemo gym support wip

* cleanup

* chore: lint

* address PR review and add more tests

* chore: lint

* post merge lora fixes for CI (#3536) [skip ci]

* post merge lora fixes for CI

* handle lora kernel auto-enable for moe without grouped_mm

* prefer not to import torch in schema validation

* address pr comments, add timeout, add tests

* roundup_power2_divisions not needed with newer pytorch versions (#3540)

* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg

* fix: robust handling of race condition on patching check (#3543) [skip ci]

* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict

* fix for ebft sync and update docs

* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Wing Lian
2026-03-25 07:38:06 -04:00
committed by GitHub
parent 2fb72798e0
commit c2bd75aff6
20 changed files with 3592 additions and 19 deletions

View File

@@ -0,0 +1,621 @@
"""Unit tests for NeMo Gym integration.
Tests the core parsing, routing, reward, and plugin wiring logic
without requiring a running NeMo Gym server or GPU.
"""
import unittest
from unittest.mock import MagicMock, patch
class TestParseAgentResponse(unittest.TestCase):
"""Tests for _parse_agent_response in multi_turn.py."""
def _parse(self, response, eos_token_id=2):
from axolotl.integrations.nemo_gym.multi_turn import _parse_agent_response
return _parse_agent_response(response, eos_token_id)
def test_empty_response_returns_defaults(self):
result = self._parse({})
assert result["prompt_ids"] == [2]
assert result["completion_ids"] == [2]
assert result["env_mask"] == [0]
assert result["reward"] == 0.0
assert result["num_turns"] == 0
def test_error_response_returns_defaults(self):
result = self._parse({"error": "something broke"})
assert result["reward"] == 0.0
assert result["num_turns"] == 0
def test_single_turn_function_call(self):
response = {
"response": {
"output": [
{
"type": "function_call",
"name": "guess_word",
"arguments": '{"guess": "crane"}',
"call_id": "call_1",
"prompt_token_ids": [10, 20, 30],
"generation_token_ids": [40, 50],
"generation_log_probs": [-0.1, -0.2],
}
]
},
"reward": 0.5,
}
result = self._parse(response)
assert result["prompt_ids"] == [10, 20, 30]
assert result["completion_ids"] == [40, 50]
assert result["env_mask"] == [1, 1] # model tokens
assert result["logprobs"] == [-0.1, -0.2]
assert result["reward"] == 0.5
assert result["num_turns"] == 1
def test_multi_turn_preserves_env_mask(self):
"""Second turn's prompt tokens (tool results) get env_mask=0."""
response = {
"response": {
"output": [
{
"type": "function_call",
"prompt_token_ids": [10, 20],
"generation_token_ids": [30, 31],
"generation_log_probs": [-0.1, -0.2],
},
{
"type": "function_call_output",
"output": '{"feedback": "XYGXY"}',
},
{
"type": "function_call",
# prompt includes original + gen + tool output
"prompt_token_ids": [10, 20, 30, 31, 100, 101, 102],
"generation_token_ids": [40, 41],
"generation_log_probs": [-0.3, -0.4],
},
]
},
"reward": 0.3,
}
result = self._parse(response)
assert result["prompt_ids"] == [10, 20]
# completion = gen1 + tool_result + gen2
assert result["completion_ids"] == [30, 31, 100, 101, 102, 40, 41]
# env_mask: gen1=model(1), tool=env(0), gen2=model(1)
assert result["env_mask"] == [1, 1, 0, 0, 0, 1, 1]
assert result["num_turns"] == 2
def test_empty_output_preserves_reward(self):
response = {
"response": {"output": []},
"reward": 0.42,
}
result = self._parse(response)
assert result["reward"] == 0.42
def test_message_only_output(self):
"""A message with text but no function calls."""
response = {
"response": {
"output": [
{
"type": "message",
"role": "assistant",
"content": [
{"type": "output_text", "text": "I'll guess crane."}
],
"prompt_token_ids": [10, 20],
"generation_token_ids": [30, 31, 32],
"generation_log_probs": [-0.1, -0.2, -0.3],
}
]
},
"reward": 0.1,
}
result = self._parse(response)
assert result["num_turns"] == 1
assert result["completion_ids"] == [30, 31, 32]
assert result["env_mask"] == [1, 1, 1]
class TestRewardEnv(unittest.TestCase):
"""Tests for reward_env passthrough function."""
def test_with_list_rewards(self):
from axolotl.integrations.nemo_gym.rewards import reward_env
result = reward_env([["comp1"], ["comp2"]], env_reward=[0.5, 0.8])
assert result == [0.5, 0.8]
def test_with_scalar_reward(self):
from axolotl.integrations.nemo_gym.rewards import reward_env
result = reward_env([["comp1"], ["comp2"]], env_reward=0.7)
assert result == [0.7, 0.7]
def test_missing_reward_returns_zeros(self):
from axolotl.integrations.nemo_gym.rewards import reward_env
result = reward_env([["comp1"], ["comp2"]])
assert result == [0.0, 0.0]
class TestRewardNemoGymVerify(unittest.TestCase):
"""Tests for reward_nemo_gym_verify with mocked HTTP."""
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
@patch("axolotl.integrations.nemo_gym.rewards.requests")
def test_calls_verify_endpoint(self, mock_requests, mock_get_urls):
from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify
mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
mock_resp = MagicMock()
mock_resp.ok = True
mock_resp.json.return_value = {"reward": 0.75}
mock_requests.post.return_value = mock_resp
result = reward_nemo_gym_verify(
completions=[[{"role": "assistant", "content": "crane"}]],
prompts=[[{"role": "user", "content": "Guess a word"}]],
resources_server_ref=[{"name": "wordle"}],
verify_extra=[{}],
)
assert result == [0.75]
mock_requests.post.assert_called_once()
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
def test_missing_server_returns_zero(self, mock_get_urls):
from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify
mock_get_urls.return_value = {}
result = reward_nemo_gym_verify(
completions=[[{"role": "assistant", "content": "crane"}]],
prompts=[[{"role": "user", "content": "Guess"}]],
resources_server_ref=[{"name": "unknown_server"}],
verify_extra=[{}],
)
assert result == [0.0]
class TestNormalizeHost(unittest.TestCase):
"""Tests for server.py _normalize_host helper."""
def test_zero_addr_normalized(self):
from axolotl.integrations.nemo_gym.server import _normalize_host
assert _normalize_host("0.0.0.0") == "127.0.0.1"
def test_localhost_normalized(self):
from axolotl.integrations.nemo_gym.server import _normalize_host
assert _normalize_host("localhost") == "127.0.0.1"
def test_loopback_passthrough(self):
from axolotl.integrations.nemo_gym.server import _normalize_host
assert _normalize_host("127.0.0.1") == "127.0.0.1"
def test_custom_fallback(self):
from axolotl.integrations.nemo_gym.server import _normalize_host
assert _normalize_host("0.0.0.0", fallback="10.0.0.1") == "10.0.0.1"
def test_real_ip_passthrough(self):
from axolotl.integrations.nemo_gym.server import _normalize_host
assert _normalize_host("192.168.1.50") == "192.168.1.50"
class TestDatasetLookupKeying(unittest.TestCase):
"""Verify dataset lookup uses last message content as key."""
def test_single_message_prompt(self):
"""Single-message prompt: [0] == [-1], both work."""
prompt = [{"role": "user", "content": "Play Wordle!"}]
assert prompt[0]["content"] == prompt[-1]["content"]
def test_multi_message_prompt_uses_last(self):
"""Multi-message prompt: must use [-1] to match data_producer lookup."""
prompt = [
{"role": "system", "content": "You are a game player."},
{"role": "user", "content": "Play Wordle!"},
]
# data_producer.py line 92 uses prompt[-1]
key = prompt[-1]["content"]
assert key == "Play Wordle!"
# Old code used prompt[0] which would be wrong here
assert prompt[0]["content"] != key
class TestAgentRefPreservation(unittest.TestCase):
"""Verify agent_ref is preserved through the dispatch chain."""
def test_data_producer_preserves_agent_ref(self):
"""Simulates the data_producer lookup logic."""
# Simulate what plugin.py builds
dataset_lookup = {
"Play Wordle!": {
"prompt": [{"role": "user", "content": "Play Wordle!"}],
"agent_ref": {"name": "wordle_simple_agent"},
"verify_extra": {
"responses_create_params": {
"input": [{"role": "user", "content": "Play Wordle!"}]
}
},
}
}
# Simulate data_producer.py logic (after fix)
prompt_text = "Play Wordle!"
full_item = dataset_lookup.get(prompt_text, {})
item = full_item.get("verify_extra", {})
if "agent_ref" in full_item and "agent_ref" not in item:
item["agent_ref"] = full_item["agent_ref"]
assert "agent_ref" in item
assert item["agent_ref"]["name"] == "wordle_simple_agent"
def test_multi_turn_preserves_agent_ref(self):
"""Simulates the multi_turn.py dispatch logic."""
dataset_lookup = {
"Play Wordle!": {
"agent_ref": {"name": "wordle_simple_agent"},
"verify_extra": {
"responses_create_params": {
"input": [{"role": "user", "content": "Play Wordle!"}]
}
},
}
}
# Simulate multi_turn.py logic (after fix)
prompt_str = "Play Wordle!"
full_item = None
for key, val in dataset_lookup.items():
if isinstance(key, str) and prompt_str == key:
full_item = val
break
dispatched = full_item.get("verify_extra", full_item)
if isinstance(dispatched, dict) and "agent_ref" not in dispatched:
agent_ref = full_item.get("agent_ref")
if agent_ref:
dispatched = {**dispatched, "agent_ref": agent_ref}
assert "agent_ref" in dispatched
assert dispatched["agent_ref"]["name"] == "wordle_simple_agent"
class TestCallAgentsRouting(unittest.TestCase):
"""Tests for _call_agents routing via agent_ref."""
def test_routes_to_correct_agent(self):
"""Items with agent_ref should route to the matching agent server."""
agent_servers = {
"wordle_agent": "http://localhost:11111",
"math_agent": "http://localhost:22222",
}
items = [
{
"agent_ref": {"name": "wordle_agent"},
"responses_create_params": {
"input": [{"role": "user", "content": "Play"}]
},
}
]
# We can't actually call the agent, but verify the URL resolution
# by checking _call_agents builds the right request
# The function uses aiohttp — just verify agent_ref lookup works
item = items[0]
agent_ref = item.get("agent_ref", {})
agent_name = agent_ref.get("name", "")
agent_url = agent_servers.get(agent_name, "")
assert agent_url == "http://localhost:11111"
def test_fallback_to_first_agent(self):
"""Items without agent_ref should use first available agent."""
agent_servers = {"default_agent": "http://localhost:33333"}
item = {
"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}
}
agent_ref = item.get("agent_ref", {})
agent_name = agent_ref.get("name", "")
agent_url = agent_servers.get(agent_name, "")
if not agent_url and agent_servers:
agent_url = next(iter(agent_servers.values()))
assert agent_url == "http://localhost:33333"
class TestPluginDefaults(unittest.TestCase):
"""Tests for plugin config enforcement."""
def test_dataloader_num_workers_forced_to_zero(self):
"""Plugin should set dataloader_num_workers=0 for NeMo Gym."""
# Simulate the plugin logic
class FakeCfg:
dataloader_num_workers = 4
nemo_gym_multi_turn = True
cfg = FakeCfg()
# Replicate plugin.get_training_args logic
if getattr(cfg, "dataloader_num_workers", None) not in (None, 0):
pass # would log warning
cfg.dataloader_num_workers = 0
assert cfg.dataloader_num_workers == 0
def test_dataloader_num_workers_none_stays_zero(self):
class FakeCfg:
dataloader_num_workers = None
cfg = FakeCfg()
cfg.dataloader_num_workers = 0
assert cfg.dataloader_num_workers == 0
class TestNemoGymE2E(unittest.TestCase):
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
Exercises the full NemoGymDataProducer.produce() pipeline with mocked HTTP
responses, verifying that multi-turn Wordle agent responses are correctly
parsed into padded tensors with proper env_mask, logprobs, and rewards.
No GPU or NeMo Gym server required.
"""
# A realistic 2-turn agent /run response (guess + feedback + guess + done)
AGENT_RESPONSE = {
"response": {
"output": [
{
"type": "function_call",
"name": "guess_word",
"arguments": '{"guess": "crane"}',
"call_id": "call_1",
"id": "call_1",
"status": "completed",
"prompt_token_ids": [1, 2, 3, 4, 5],
"generation_token_ids": [10, 11, 12, 13],
"generation_log_probs": [-0.1, -0.2, -0.3, -0.4],
},
{
"type": "function_call_output",
"call_id": "call_1",
"output": '{"feedback":"XYGXY","guesses_remaining":5,"done":false}',
},
{
"type": "function_call",
"name": "guess_word",
"arguments": '{"guess": "slide"}',
"call_id": "call_2",
"id": "call_2",
"status": "completed",
# prompt = original(5) + gen1(4) + tool_output(3 tokens)
"prompt_token_ids": [1, 2, 3, 4, 5, 10, 11, 12, 13, 50, 51, 52],
"generation_token_ids": [20, 21, 22],
"generation_log_probs": [-0.5, -0.6, -0.7],
},
],
},
"reward": 0.42,
}
def _make_mock_trainer(self):
"""Create a minimal mock trainer with the attributes produce() needs."""
trainer = MagicMock()
trainer.accelerator.is_main_process = True
trainer.accelerator.device = "cpu"
trainer.max_completion_length = 512
trainer.temperature = 0.8
trainer.pad_token_id = 0
trainer.processing_class.eos_token_id = 2
trainer.processing_class.batch_decode.return_value = ["crane slide"]
return trainer
@patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
def test_produce_returns_valid_rollout_dataset(self, mock_call_agents):
"""Full pipeline: produce() → _call_agents (mocked) → parse → RolloutDataset."""
from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer
# Mock _call_agents — it's async, so return a coroutine
async def fake_call_agents(**kwargs):
return [self.AGENT_RESPONSE, self.AGENT_RESPONSE]
mock_call_agents.side_effect = fake_call_agents
# Build a minimal mock of GRPODataProducer's __init__ dependencies
# We can't easily call super().__init__, so we'll set attributes directly
producer = NemoGymDataProducer.__new__(NemoGymDataProducer)
producer._agent_servers = {"wordle_agent": "http://mock:9999"}
producer._dataset_lookup = {
"Play Wordle!": {
"agent_ref": {"name": "wordle_agent"},
"verify_extra": {
"responses_create_params": {
"input": [{"role": "user", "content": "Play Wordle!"}],
}
},
}
}
producer._request_timeout = 30
producer._num_generations = 2
# Mock the trainer
trainer = self._make_mock_trainer()
producer._trainer = trainer
# Mock the prompt iterator (returns a batch of 1 input)
producer._prompt_iter = iter(
[
[
{
"prompt": [{"role": "user", "content": "Play Wordle!"}],
}
]
]
)
producer._prompt_dl = [
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
]
# Call produce
result = producer.produce(model=MagicMock(), global_step=1)
# Verify result structure
assert result is not None
data = result._data
# Check tensor shapes — 2 rollouts (num_generations=2)
assert data["prompt_ids"].shape[0] == 2
assert data["completion_ids"].shape[0] == 2
assert data["completion_mask"].shape[0] == 2
assert data["sampling_per_token_logps"].shape[0] == 2
assert data["tool_mask"].shape[0] == 2
# Verify completion content — each rollout should have:
# gen1(4) + tool_output(3) + gen2(3) = 10 tokens
# (padded to same length across the batch, but both are same here)
comp_len = data["completion_mask"][0].sum().item()
assert comp_len == 10, f"Expected 10 completion tokens, got {comp_len}"
# Verify env_mask: gen1=1,1,1,1 tool=0,0,0 gen2=1,1,1
tool_mask = data["tool_mask"][0][:comp_len].tolist()
assert tool_mask == [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]
# Verify logprobs are populated (use approx for float32 precision)
import pytest
logps = data["sampling_per_token_logps"][0][:comp_len].tolist()
assert logps[:4] == pytest.approx([-0.1, -0.2, -0.3, -0.4], abs=1e-6)
assert logps[4:7] == pytest.approx([0.0, 0.0, 0.0], abs=1e-6)
assert logps[7:10] == pytest.approx([-0.5, -0.6, -0.7], abs=1e-6)
# Verify rewards were injected into inputs
assert data["_deferred_inputs"][0]["env_reward"] == 0.42
assert data["_deferred_inputs"][1]["env_reward"] == 0.42
# Verify deferred scoring markers
assert data["_pending_policy_logps"] is True
@patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
def test_produce_handles_failed_agent_response(self, mock_call_agents):
"""Failed agent responses should produce default (length-1) rollouts."""
from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer
# One success, one failure — async mock
async def fake_call_agents(**kwargs):
return [
self.AGENT_RESPONSE,
{
"error": "Connection timeout",
"response": {"output": []},
"reward": 0.0,
},
]
mock_call_agents.side_effect = fake_call_agents
producer = NemoGymDataProducer.__new__(NemoGymDataProducer)
producer._agent_servers = {"wordle_agent": "http://mock:9999"}
producer._dataset_lookup = {}
producer._request_timeout = 30
producer._num_generations = 2
producer._trainer = self._make_mock_trainer()
producer._prompt_iter = iter(
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
)
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
result = producer.produce(model=MagicMock(), global_step=1)
assert result is not None
data = result._data
# Both rollouts present
assert data["completion_ids"].shape[0] == 2
# First rollout has real tokens, second has just eos (length 1)
mask0 = data["completion_mask"][0].sum().item()
mask1 = data["completion_mask"][1].sum().item()
assert mask0 == 10 # full response
assert mask1 == 1 # default fallback (just eos)
# Rewards: success=0.42, failure=0.0
assert data["_deferred_inputs"][0]["env_reward"] == 0.42
assert data["_deferred_inputs"][1]["env_reward"] == 0.0
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
@patch("axolotl.integrations.nemo_gym.rewards.requests")
def test_reward_functions_chain(self, mock_requests, mock_get_urls):
"""Test that reward_env and reward_nemo_gym_verify can be used together."""
from axolotl.integrations.nemo_gym.rewards import (
reward_env,
reward_nemo_gym_verify,
)
completions = [[{"role": "assistant", "content": "crane"}]]
prompts = [[{"role": "user", "content": "Guess"}]]
# reward_env: passthrough from agent
env_result = reward_env(completions, prompts, env_reward=[0.42])
assert env_result == [0.42]
# reward_nemo_gym_verify: calls /verify
mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
mock_resp = MagicMock()
mock_resp.ok = True
mock_resp.json.return_value = {"reward": 0.75}
mock_requests.post.return_value = mock_resp
verify_result = reward_nemo_gym_verify(
completions,
prompts,
resources_server_ref=[{"name": "wordle"}],
verify_extra=[{}],
)
assert verify_result == [0.75]
# Both rewards can coexist (as they would in a multi-reward config)
combined = [e + v for e, v in zip(env_result, verify_result, strict=True)]
assert combined == [1.17]
class TestLoRASyncSetup(unittest.TestCase):
"""Tests for _setup_lora_sync delegation logic."""
def test_delegates_to_async_trainer(self):
"""When trainer has _sync_lora_adapter, the closure should delegate."""
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin.__new__(NemoGymPlugin)
trainer = MagicMock()
trainer._sync_lora_adapter = MagicMock()
trainer.vllm_generation = MagicMock()
plugin._setup_lora_sync(trainer)
# The closure should be installed
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_called_once()
def test_check_lora_endpoint_skips_non_main_rank(self):
"""_check_lora_endpoint should not crash when vllm_client is absent (rank 1)."""
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
vllm_gen = MagicMock(spec=[]) # No attributes at all
# Should not raise
NemoGymPlugin._check_lora_endpoint(vllm_gen)
if __name__ == "__main__":
unittest.main()