# Source file: axolotl/tests/integrations/test_nemo_gym.py
# Snapshot taken 2026-04-22 09:05:46 -04:00 (944 lines, 35 KiB, Python).

"""Unit tests for NeMo Gym integration.
Tests the core parsing, routing, reward, and plugin wiring logic
without requiring a running NeMo Gym server or GPU.
"""
import unittest
from unittest.mock import MagicMock, patch
class TestParseAgentResponse(unittest.TestCase):
    """Checks ``_parse_agent_response`` (multi_turn.py) against hand-built payloads."""

    def _parse(self, response, eos_token_id=2):
        """Parse ``response`` with the given EOS id; the import is deferred to call time."""
        from axolotl.integrations.nemo_gym.multi_turn import (
            _parse_agent_response as parse_response,
        )

        return parse_response(response, eos_token_id)

    @staticmethod
    def _payload(turns, reward):
        """Wrap output items in the ``response``/``output`` envelope plus a reward."""
        return {"response": {"output": turns}, "reward": reward}

    def test_empty_response_returns_defaults(self):
        """An empty payload yields single-EOS ids, a zero mask, no reward, no turns."""
        parsed = self._parse({})
        defaults = {
            "prompt_ids": [2],
            "completion_ids": [2],
            "env_mask": [0],
            "reward": 0.0,
            "num_turns": 0,
        }
        for field, expected in defaults.items():
            assert parsed[field] == expected, field

    def test_error_response_returns_defaults(self):
        """A payload carrying only an ``error`` key reports zero reward and zero turns."""
        parsed = self._parse({"error": "something broke"})
        assert parsed["reward"] == 0.0
        assert parsed["num_turns"] == 0

    def test_single_turn_function_call(self):
        """One function call maps straight onto prompt/completion/logprob fields."""
        call = {
            "type": "function_call",
            "name": "guess_word",
            "arguments": '{"guess": "crane"}',
            "call_id": "call_1",
            "prompt_token_ids": [10, 20, 30],
            "generation_token_ids": [40, 50],
            "generation_log_probs": [-0.1, -0.2],
        }
        parsed = self._parse(self._payload([call], 0.5))
        assert parsed["prompt_ids"] == [10, 20, 30]
        assert parsed["completion_ids"] == [40, 50]
        # Every completion token was produced by the model.
        assert parsed["env_mask"] == [1, 1]
        assert parsed["logprobs"] == [-0.1, -0.2]
        assert parsed["reward"] == 0.5
        assert parsed["num_turns"] == 1

    def test_multi_turn_preserves_env_mask(self):
        """Second turn's prompt tokens (tool results) get env_mask=0."""
        first_call = {
            "type": "function_call",
            "prompt_token_ids": [10, 20],
            "generation_token_ids": [30, 31],
            "generation_log_probs": [-0.1, -0.2],
        }
        tool_result = {
            "type": "function_call_output",
            "output": '{"feedback": "XYGXY"}',
        }
        second_call = {
            "type": "function_call",
            # This prompt is the original prompt + first generation + tool output.
            "prompt_token_ids": [10, 20, 30, 31, 100, 101, 102],
            "generation_token_ids": [40, 41],
            "generation_log_probs": [-0.3, -0.4],
        }
        parsed = self._parse(
            self._payload([first_call, tool_result, second_call], 0.3)
        )
        assert parsed["prompt_ids"] == [10, 20]
        # Completion is first generation, then tool-result tokens, then second generation.
        assert parsed["completion_ids"] == [30, 31, 100, 101, 102, 40, 41]
        # Mask is 1 for model tokens and 0 for the environment (tool) tokens.
        assert parsed["env_mask"] == [1, 1, 0, 0, 0, 1, 1]
        assert parsed["num_turns"] == 2

    def test_empty_output_preserves_reward(self):
        """The reward is passed through even when the output list is empty."""
        parsed = self._parse(self._payload([], 0.42))
        assert parsed["reward"] == 0.42

    def test_message_only_output(self):
        """A message with text but no function calls."""
        message = {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "I'll guess crane."}],
            "prompt_token_ids": [10, 20],
            "generation_token_ids": [30, 31, 32],
            "generation_log_probs": [-0.1, -0.2, -0.3],
        }
        parsed = self._parse(self._payload([message], 0.1))
        assert parsed["num_turns"] == 1
        assert parsed["completion_ids"] == [30, 31, 32]
        assert parsed["env_mask"] == [1, 1, 1]
class TestRewardEnv(unittest.TestCase):
    """Covers the ``reward_env`` passthrough for list, scalar and absent rewards."""

    @staticmethod
    def _score(**reward_kwargs):
        """Call ``reward_env`` on two dummy completions, forwarding any keyword args."""
        from axolotl.integrations.nemo_gym.rewards import reward_env

        return reward_env([["comp1"], ["comp2"]], **reward_kwargs)

    def test_with_list_rewards(self):
        """A per-completion reward list comes back unchanged."""
        assert self._score(env_reward=[0.5, 0.8]) == [0.5, 0.8]

    def test_with_scalar_reward(self):
        """A scalar reward is repeated once per completion."""
        assert self._score(env_reward=0.7) == [0.7, 0.7]

    def test_missing_reward_returns_zeros(self):
        """Omitting ``env_reward`` gives a zero for every completion."""
        assert self._score() == [0.0, 0.0]
class TestRewardNemoGymVerify(unittest.TestCase):
    """Covers ``reward_nemo_gym_verify`` with URL discovery and HTTP mocked out."""

    @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
    @patch("axolotl.integrations.nemo_gym.rewards.requests")
    def test_calls_verify_endpoint(self, mock_requests, mock_get_urls):
        """A known server is POSTed to exactly once and its reward is returned."""
        from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify

        mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
        http_response = MagicMock()
        http_response.ok = True
        http_response.json.return_value = {"reward": 0.75}
        mock_requests.post.return_value = http_response

        rewards = reward_nemo_gym_verify(
            completions=[[{"role": "assistant", "content": "crane"}]],
            prompts=[[{"role": "user", "content": "Guess a word"}]],
            resources_server_ref=[{"name": "wordle"}],
            verify_extra=[{}],
        )

        assert rewards == [0.75]
        mock_requests.post.assert_called_once()

    @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
    def test_missing_server_returns_zero(self, mock_get_urls):
        """With no verify URL registered for the server, the reward is 0.0."""
        from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify

        mock_get_urls.return_value = {}

        rewards = reward_nemo_gym_verify(
            completions=[[{"role": "assistant", "content": "crane"}]],
            prompts=[[{"role": "user", "content": "Guess"}]],
            resources_server_ref=[{"name": "unknown_server"}],
            verify_extra=[{}],
        )

        assert rewards == [0.0]
class TestNormalizeHost(unittest.TestCase):
    """Covers the ``_normalize_host`` helper from server.py."""

    @staticmethod
    def _normalize(*args, **kwargs):
        """Forward to ``_normalize_host``; the import is deferred to call time."""
        from axolotl.integrations.nemo_gym.server import _normalize_host

        return _normalize_host(*args, **kwargs)

    def test_zero_addr_normalized(self):
        """The wildcard address maps to loopback."""
        assert self._normalize("0.0.0.0") == "127.0.0.1"

    def test_localhost_normalized(self):
        """The ``localhost`` name maps to the loopback IP."""
        assert self._normalize("localhost") == "127.0.0.1"

    def test_loopback_passthrough(self):
        """The loopback IP is returned as-is."""
        assert self._normalize("127.0.0.1") == "127.0.0.1"

    def test_custom_fallback(self):
        """An explicit ``fallback`` replaces the wildcard address."""
        assert self._normalize("0.0.0.0", fallback="10.0.0.1") == "10.0.0.1"

    def test_real_ip_passthrough(self):
        """A concrete, non-loopback IP is left untouched."""
        assert self._normalize("192.168.1.50") == "192.168.1.50"
class TestDatasetLookupKeying(unittest.TestCase):
    """Documents that the dataset lookup key is the last message's content."""

    def test_single_message_prompt(self):
        """Single-message prompt: [0] == [-1], both work."""
        only_message = {"role": "user", "content": "Play Wordle!"}
        prompt = [only_message]
        first, last = prompt[0], prompt[-1]
        assert first["content"] == last["content"]

    def test_multi_message_prompt_uses_last(self):
        """Multi-message prompt: must use [-1] to match data_producer lookup."""
        system_message = {"role": "system", "content": "You are a game player."}
        user_message = {"role": "user", "content": "Play Wordle!"}
        prompt = [system_message, user_message]

        # Per the original note, data_producer.py keys its lookup on prompt[-1].
        lookup_key = prompt[-1]["content"]
        assert lookup_key == "Play Wordle!"

        # Keying on prompt[0] (the older behaviour) would pick the system message.
        assert prompt[0]["content"] != lookup_key
class TestAgentRefPreservation(unittest.TestCase):
    """Verify agent_ref is preserved through the dispatch chain.

    Both tests re-create the lookup logic inline instead of importing the
    production modules, so they document the intended behaviour rather than
    exercise ``data_producer.py`` / ``multi_turn.py`` directly.
    """

    @staticmethod
    def _find_exact(dataset_lookup, prompt_str):
        """Return the entry whose string key equals ``prompt_str``, else ``None``.

        Mirrors the linear scan simulated for multi_turn.py: non-string keys
        are skipped and the first exact match wins.
        """
        for key, val in dataset_lookup.items():
            if isinstance(key, str) and prompt_str == key:
                return val
        return None

    def test_data_producer_preserves_agent_ref(self):
        """Simulates the data_producer lookup logic."""
        # Simulate what plugin.py builds
        dataset_lookup = {
            "Play Wordle!": {
                "prompt": [{"role": "user", "content": "Play Wordle!"}],
                "agent_ref": {"name": "wordle_simple_agent"},
                "verify_extra": {
                    "responses_create_params": {
                        "input": [{"role": "user", "content": "Play Wordle!"}]
                    }
                },
            }
        }

        # Simulate data_producer.py logic (after fix)
        prompt_text = "Play Wordle!"
        full_item = dataset_lookup.get(prompt_text, {})
        item = full_item.get("verify_extra", {})
        if "agent_ref" in full_item and "agent_ref" not in item:
            item["agent_ref"] = full_item["agent_ref"]

        assert "agent_ref" in item
        assert item["agent_ref"]["name"] == "wordle_simple_agent"

    def test_multi_turn_preserves_agent_ref(self):
        """Simulates the multi_turn.py dispatch logic."""
        dataset_lookup = {
            "Play Wordle!": {
                "agent_ref": {"name": "wordle_simple_agent"},
                "verify_extra": {
                    "responses_create_params": {
                        "input": [{"role": "user", "content": "Play Wordle!"}]
                    }
                },
            }
        }

        # Simulate multi_turn.py logic (after fix)
        prompt_str = "Play Wordle!"
        full_item = self._find_exact(dataset_lookup, prompt_str)
        # Fail with a clear message instead of an AttributeError on None.get(...)
        # if the fixture and the prompt ever drift apart.
        assert full_item is not None, f"no dataset entry for prompt {prompt_str!r}"

        dispatched = full_item.get("verify_extra", full_item)
        if isinstance(dispatched, dict) and "agent_ref" not in dispatched:
            agent_ref = full_item.get("agent_ref")
            if agent_ref:
                # Build a new dict so the lookup entry itself is not mutated.
                dispatched = {**dispatched, "agent_ref": agent_ref}

        assert "agent_ref" in dispatched
        assert dispatched["agent_ref"]["name"] == "wordle_simple_agent"
class TestCallAgentsRouting(unittest.TestCase):
    """Documents how ``_call_agents`` is expected to pick a server via agent_ref.

    The real function is not invoked here; each test replays only the
    name-to-URL resolution step.
    """

    def test_routes_to_correct_agent(self):
        """Items with agent_ref should route to the matching agent server."""
        servers = {
            "wordle_agent": "http://localhost:11111",
            "math_agent": "http://localhost:22222",
        }
        request_item = {
            "agent_ref": {"name": "wordle_agent"},
            "responses_create_params": {
                "input": [{"role": "user", "content": "Play"}]
            },
        }

        # The agent cannot actually be called here, so only the agent_ref
        # lookup that selects the URL is checked.
        name = request_item.get("agent_ref", {}).get("name", "")
        resolved = servers.get(name, "")

        assert resolved == "http://localhost:11111"

    def test_fallback_to_first_agent(self):
        """Items without agent_ref should use first available agent."""
        servers = {"default_agent": "http://localhost:33333"}
        request_item = {
            "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}
        }

        name = request_item.get("agent_ref", {}).get("name", "")
        resolved = servers.get(name, "")
        if servers and not resolved:
            # No named match: fall back to the first registered server.
            resolved = next(iter(servers.values()))

        assert resolved == "http://localhost:33333"
class TestPluginDefaults(unittest.TestCase):
    """Replays the plugin's config enforcement on stand-in config objects."""

    def test_dataloader_num_workers_forced_to_zero(self):
        """Plugin should set dataloader_num_workers=0 for NeMo Gym."""

        # Stand-in config; the plugin itself is not imported here.
        class FakeCfg:
            dataloader_num_workers = 4
            nemo_gym_multi_turn = True

        cfg = FakeCfg()

        # Mirror of the plugin.get_training_args logic.
        requested = getattr(cfg, "dataloader_num_workers", None)
        if requested is not None and requested != 0:
            pass  # the plugin would log a warning at this point
        cfg.dataloader_num_workers = 0

        assert cfg.dataloader_num_workers == 0

    def test_dataloader_num_workers_none_stays_zero(self):
        """An unset (``None``) worker count also ends up as 0."""

        class FakeCfg:
            dataloader_num_workers = None

        cfg = FakeCfg()
        cfg.dataloader_num_workers = 0

        assert cfg.dataloader_num_workers == 0
class TestSelectWeightSyncTransport(unittest.TestCase):
    """Table-style checks of ``select_weight_sync_transport`` (pure logic)."""

    def _caps(self, **kwargs):
        """Build a probed capabilities object with the given flags set."""
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        capabilities = VLLMWeightSyncCapabilities(probed=True)
        for flag, value in kwargs.items():
            setattr(capabilities, flag, value)
        return capabilities

    @staticmethod
    def _select(caps, has_lora, vllm_lora_sync_pref):
        """Call ``select_weight_sync_transport`` with keyword arguments."""
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        return select_weight_sync_transport(
            caps, has_lora=has_lora, vllm_lora_sync_pref=vllm_lora_sync_pref
        )

    def test_lora_with_native_endpoint(self):
        """LoRA with the ``lora_filesystem`` capability selects filesystem sync."""
        chosen = self._select(self._caps(lora_filesystem=True), True, True)
        assert chosen == "lora_filesystem"

    def test_lora_with_axolotl_endpoint(self):
        """LoRA with the ``lora_axolotl`` capability also selects filesystem sync."""
        chosen = self._select(self._caps(lora_axolotl=True), True, False)
        assert chosen == "lora_filesystem"

    def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
        """LoRA with only the NCCL capability selects NCCL."""
        chosen = self._select(self._caps(nccl=True), True, False)
        assert chosen == "nccl"

    def test_full_param_prefers_nccl(self):
        """Full-parameter training picks NCCL when NCCL and HTTP are both available."""
        chosen = self._select(self._caps(nccl=True, http_full=True), False, False)
        assert chosen == "nccl"

    def test_full_param_falls_back_to_http(self):
        """Full-parameter training picks HTTP when it is the only capability."""
        chosen = self._select(self._caps(http_full=True), False, False)
        assert chosen == "http_full"

    def test_full_param_no_routes_returns_none(self):
        """Full-parameter training with no capability flags set gives ``"none"``."""
        chosen = self._select(self._caps(), False, False)
        assert chosen == "none"

    def test_lora_no_routes_returns_none(self):
        """LoRA with no capability flags set gives ``"none"``."""
        chosen = self._select(self._caps(), True, True)
        assert chosen == "none"
class TestProbeVllmWeightSync(unittest.TestCase):
    """``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""

    @staticmethod
    def _spec(post_routes, get_routes=()):
        """Build a minimal OpenAPI-like ``{"paths": ...}`` dict from route names."""
        paths = {route: {"get": {}} for route in get_routes}
        paths.update({route: {"post": {}} for route in post_routes})
        return {"paths": paths}

    @staticmethod
    def _probe_with_spec(spec):
        """Probe a fake server whose mocked ``requests.get`` response serves ``spec``."""
        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        with patch("requests.get") as mock_get:
            fake_response = mock_get.return_value
            fake_response.raise_for_status = lambda: None
            fake_response.json = lambda: spec
            return probe_vllm_weight_sync("http://localhost:8000")

    def test_stock_vllm_with_lora_enabled(self):
        """Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
        spec = self._spec(
            post_routes=[
                "/v1/load_lora_adapter",
                "/v1/unload_lora_adapter",
                "/v1/completions",
            ],
            get_routes=["/v1/models"],
        )

        caps = self._probe_with_spec(spec)

        assert caps.probed is True
        assert caps.lora_filesystem is True
        assert caps.lora_axolotl is False
        assert caps.nccl is False
        assert caps.http_full is False

    def test_axolotl_serve_lora_full_capabilities(self):
        """``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
        spec = self._spec(
            post_routes=[
                "/init_communicator/",
                "/update_named_param/",
                "/batch_update_named_params/",
                "/set_lora_adapter/",
                "/clear_lora_adapter/",
                "/http_update_weights/",
                "/v1/load_lora_adapter",
            ]
        )

        caps = self._probe_with_spec(spec)

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_axolotl is True
        assert caps.lora_filesystem is True
        assert caps.http_full is True

    def test_trl_vllm_serve_nccl_only(self):
        """``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
        spec = self._spec(
            post_routes=[
                "/init_communicator/",
                "/update_named_param/",
                "/batch_update_named_params/",
                "/close_communicator/",
                "/generate/",
            ]
        )

        caps = self._probe_with_spec(spec)

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_filesystem is False
        assert caps.lora_axolotl is False
        assert caps.http_full is False

    def test_unreachable_server_records_error(self):
        """A connection failure leaves ``probed`` False and records the error name."""
        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        with patch("requests.get") as mock_get:
            mock_get.side_effect = ConnectionError("Connection refused")
            caps = probe_vllm_weight_sync("http://localhost:9999")

        assert caps.probed is False
        assert caps.probe_error is not None
        assert "ConnectionError" in caps.probe_error
        assert caps.nccl is False
        assert caps.lora_filesystem is False
class TestPluginWeightSyncEnforcement(unittest.TestCase):
    """End-to-end test of post_trainer_create's transport-selection branch.

    The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
    leaving the trainer learning in isolation while vLLM kept serving the
    unmodified base model. After the fix:

    - LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
    - LoRA + only NCCL endpoint → uses NCCL broadcast
    - Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
    - Full FT + HTTP endpoint → raises NotImplementedError (step 3)
    - No usable transport → raises ValueError with a precise diagnosis
    """

    @staticmethod
    def _fake_cfg(adapter, vllm_lora_sync):
        """Return a bare config object carrying the attributes these tests set."""

        class FakeTRL:
            pass

        class FakeCfg:
            pass

        trl = FakeTRL()
        trl_fields = {
            "vllm_lora_sync": vllm_lora_sync,
            "vllm_server_host": "127.0.0.1",
            "vllm_server_port": 8000,
        }
        for attr, value in trl_fields.items():
            setattr(trl, attr, value)

        cfg = FakeCfg()
        cfg_fields = {
            "nemo_gym_enabled": True,
            "nemo_gym_model_name": None,
            "base_model": "test/model",
            "nemo_gym_verify_timeout": 30,
            "nemo_gym_multi_turn": True,
            "adapter": adapter,
            "trl": trl,
        }
        for attr, value in cfg_fields.items():
            setattr(cfg, attr, value)
        return cfg

    @staticmethod
    def _fake_trainer():
        """Return a trainer stub whose ``vllm_generation.sync_weights`` is a no-op."""

        class FakeVLLMGen:
            @staticmethod
            def sync_weights():
                return None

        class FakeTrainer:
            vllm_generation = FakeVLLMGen()

        return FakeTrainer()

    @staticmethod
    def _caps(**kwargs):
        """Build a probed capabilities object with the given flags set."""
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        capabilities = VLLMWeightSyncCapabilities(probed=True)
        for flag, value in kwargs.items():
            setattr(capabilities, flag, value)
        return capabilities

    def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
        """LoRA + a LoRA endpoint runs setup, endpoint check and multi-turn wiring once each."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(lora_filesystem=True)
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
        trainer = self._fake_trainer()

        with (
            patch.object(plugin, "_setup_lora_sync") as setup_mock,
            patch.object(plugin, "_check_lora_endpoint") as check_mock,
            patch.object(plugin, "_wire_multi_turn") as wire_mock,
        ):
            plugin.post_trainer_create(cfg, trainer)

        for hook in (setup_mock, check_mock, wire_mock):
            hook.assert_called_once()

    def test_lora_with_no_routes_raises_with_lora_specific_message(self):
        """LoRA with no capability flags set raises ValueError naming the LoRA remedies."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        # Probed, but every capability flag is left unset.
        plugin._vllm_caps = self._caps()
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        message = str(ctx.exception)
        for fragment in (
            "no-op trainer",
            "load_lora_adapter",
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING",
        ):
            assert fragment in message

    def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
        """Full fine-tune + NCCL capability proceeds and wires multi-turn once."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(nccl=True)
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with patch.object(plugin, "_wire_multi_turn") as wire_mock:
            plugin.post_trainer_create(cfg, trainer)

        wire_mock.assert_called_once()

    def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
        """Full fine-tune + HTTP-only capability raises NotImplementedError."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(http_full=True)
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with self.assertRaises(NotImplementedError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        assert "HTTP weight sync" in str(ctx.exception)

    def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
        """Full fine-tune with no capability flags set raises ValueError naming the missing routes."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps()
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        message = str(ctx.exception)
        for fragment in ("no-op trainer", "init_communicator", "http_update_weights"):
            assert fragment in message

    def test_unprobed_caps_raises_with_probe_failure_message(self):
        """Without probed capabilities the ValueError mentions the failed probe."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        # ``_vllm_caps`` is deliberately left at its default (None). Per the
        # original note, post_trainer_create then falls back to a fresh
        # VLLMWeightSyncCapabilities() with probed=False, so the error should
        # mention probing.
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
        trainer = self._fake_trainer()

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)

        assert "could not probe" in str(ctx.exception)
class TestNemoGymE2E(unittest.TestCase):
    """End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.

    Exercises the full NemoGymDataProducer.produce() pipeline with mocked HTTP
    responses, verifying that multi-turn Wordle agent responses are correctly
    parsed into padded tensors with proper env_mask, logprobs, and rewards.
    No GPU or NeMo Gym server required.
    """

    # A realistic 2-turn agent /run response (guess + feedback + guess + done)
    AGENT_RESPONSE = {
        "response": {
            "output": [
                {
                    "type": "function_call",
                    "name": "guess_word",
                    "arguments": '{"guess": "crane"}',
                    "call_id": "call_1",
                    "id": "call_1",
                    "status": "completed",
                    "prompt_token_ids": [1, 2, 3, 4, 5],
                    "generation_token_ids": [10, 11, 12, 13],
                    "generation_log_probs": [-0.1, -0.2, -0.3, -0.4],
                },
                {
                    "type": "function_call_output",
                    "call_id": "call_1",
                    "output": '{"feedback":"XYGXY","guesses_remaining":5,"done":false}',
                },
                {
                    "type": "function_call",
                    "name": "guess_word",
                    "arguments": '{"guess": "slide"}',
                    "call_id": "call_2",
                    "id": "call_2",
                    "status": "completed",
                    # prompt = original(5) + gen1(4) + tool_output(3 tokens)
                    "prompt_token_ids": [1, 2, 3, 4, 5, 10, 11, 12, 13, 50, 51, 52],
                    "generation_token_ids": [20, 21, 22],
                    "generation_log_probs": [-0.5, -0.6, -0.7],
                },
            ],
        },
        "reward": 0.42,
    }

    def _make_mock_trainer(self):
        """Create a minimal mock trainer with the attributes produce() needs."""
        trainer = MagicMock()
        trainer.accelerator.is_main_process = True
        trainer.accelerator.device = "cpu"
        trainer.max_completion_length = 512
        trainer.temperature = 0.8
        trainer.pad_token_id = 0
        trainer.processing_class.eos_token_id = 2
        trainer.processing_class.batch_decode.return_value = ["crane slide"]
        return trainer

    def _make_producer(self, dataset_lookup, prompt_text):
        """Build a ``NemoGymDataProducer`` without running its ``__init__``.

        The producer's constructor dependencies are not easy to satisfy in a
        unit test, so the instance is created with ``__new__`` and the
        attributes the tests rely on are assigned directly.

        Args:
            dataset_lookup: Value assigned to ``producer._dataset_lookup``.
            prompt_text: User-message content used for every prompt in the batch.

        Returns:
            The partially initialised producer, primed with a single prompt
            batch of ``_num_generations`` identical prompts.
        """
        from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer

        producer = NemoGymDataProducer.__new__(NemoGymDataProducer)
        producer._agent_servers = {"wordle_agent": "http://mock:9999"}
        producer._dataset_lookup = dataset_lookup
        producer._request_timeout = 30
        producer._num_generations = 2
        producer._trainer = self._make_mock_trainer()

        # RepeatSampler(mini_repeat_count=num_generations) pre-expands prompts,
        # so the iterator yields num_generations consecutive copies of each
        # unique prompt — one entry per rollout.
        prompt_batch = [
            {"prompt": [{"role": "user", "content": prompt_text}]}
            for _ in range(producer._num_generations)
        ]
        producer._prompt_iter = iter([prompt_batch])
        producer._prompt_dl = [prompt_batch]
        return producer

    @patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
    def test_produce_returns_valid_rollout_dataset(self, mock_call_agents):
        """Full pipeline: produce() → _call_agents (mocked) → parse → RolloutDataset."""
        import pytest

        # _call_agents is async, so the stand-in must be a coroutine function.
        async def fake_call_agents(**kwargs):
            return [self.AGENT_RESPONSE, self.AGENT_RESPONSE]

        mock_call_agents.side_effect = fake_call_agents

        producer = self._make_producer(
            dataset_lookup={
                "Play Wordle!": {
                    "agent_ref": {"name": "wordle_agent"},
                    "verify_extra": {
                        "responses_create_params": {
                            "input": [{"role": "user", "content": "Play Wordle!"}],
                        }
                    },
                }
            },
            prompt_text="Play Wordle!",
        )

        result = producer.produce(model=MagicMock(), global_step=1)

        # Verify result structure
        assert result is not None
        data = result._data

        # Check tensor shapes — 2 rollouts (num_generations=2)
        for tensor_key in (
            "prompt_ids",
            "completion_ids",
            "completion_mask",
            "sampling_per_token_logps",
            "tool_mask",
        ):
            assert data[tensor_key].shape[0] == 2, tensor_key

        # Verify completion content — each rollout should have:
        # gen1(4) + tool_output(3) + gen2(3) = 10 tokens
        # (padded to same length across the batch, but both are same here)
        comp_len = data["completion_mask"][0].sum().item()
        assert comp_len == 10, f"Expected 10 completion tokens, got {comp_len}"

        # Verify env_mask: gen1=1,1,1,1 tool=0,0,0 gen2=1,1,1
        tool_mask = data["tool_mask"][0][:comp_len].tolist()
        assert tool_mask == [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]

        # Verify logprobs are populated (use approx for float32 precision)
        logps = data["sampling_per_token_logps"][0][:comp_len].tolist()
        assert logps[:4] == pytest.approx([-0.1, -0.2, -0.3, -0.4], abs=1e-6)
        assert logps[4:7] == pytest.approx([0.0, 0.0, 0.0], abs=1e-6)
        assert logps[7:10] == pytest.approx([-0.5, -0.6, -0.7], abs=1e-6)

        # Verify rewards were injected into inputs
        assert data["_deferred_inputs"][0]["env_reward"] == 0.42
        assert data["_deferred_inputs"][1]["env_reward"] == 0.42

        # Verify deferred scoring markers
        assert data["_pending_policy_logps"] is True

    @patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
    def test_produce_handles_failed_agent_response(self, mock_call_agents):
        """Failed agent responses should produce default (length-1) rollouts."""

        # One success, one failure — async stand-in for _call_agents.
        async def fake_call_agents(**kwargs):
            return [
                self.AGENT_RESPONSE,
                {
                    "error": "Connection timeout",
                    "response": {"output": []},
                    "reward": 0.0,
                },
            ]

        mock_call_agents.side_effect = fake_call_agents

        producer = self._make_producer(dataset_lookup={}, prompt_text="Play!")

        result = producer.produce(model=MagicMock(), global_step=1)
        assert result is not None
        data = result._data

        # Both rollouts present
        assert data["completion_ids"].shape[0] == 2

        # First rollout has real tokens, second has just eos (length 1)
        mask0 = data["completion_mask"][0].sum().item()
        mask1 = data["completion_mask"][1].sum().item()
        assert mask0 == 10  # full response
        assert mask1 == 1  # default fallback (just eos)

        # Rewards: success=0.42, failure=0.0
        assert data["_deferred_inputs"][0]["env_reward"] == 0.42
        assert data["_deferred_inputs"][1]["env_reward"] == 0.0

    @patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
    @patch("axolotl.integrations.nemo_gym.rewards.requests")
    def test_reward_functions_chain(self, mock_requests, mock_get_urls):
        """Test that reward_env and reward_nemo_gym_verify can be used together."""
        import pytest

        from axolotl.integrations.nemo_gym.rewards import (
            reward_env,
            reward_nemo_gym_verify,
        )

        completions = [[{"role": "assistant", "content": "crane"}]]
        prompts = [[{"role": "user", "content": "Guess"}]]

        # reward_env: passthrough from agent
        env_result = reward_env(completions, prompts, env_reward=[0.42])
        assert env_result == [0.42]

        # reward_nemo_gym_verify: calls /verify
        mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
        mock_resp = MagicMock()
        mock_resp.ok = True
        mock_resp.json.return_value = {"reward": 0.75}
        mock_requests.post.return_value = mock_resp

        verify_result = reward_nemo_gym_verify(
            completions,
            prompts,
            resources_server_ref=[{"name": "wordle"}],
            verify_extra=[{}],
        )
        assert verify_result == [0.75]

        # Both rewards can coexist (as they would in a multi-reward config).
        # The sum is a computed float, so compare approximately rather than
        # relying on 0.42 + 0.75 rounding to exactly the same double as 1.17.
        combined = [e + v for e, v in zip(env_result, verify_result, strict=True)]
        assert combined == pytest.approx([1.17])
class TestLoRASyncSetup(unittest.TestCase):
    """Covers how ``_setup_lora_sync`` and ``_check_lora_endpoint`` behave on stubs."""

    def test_delegates_to_async_trainer(self):
        """When trainer has _sync_lora_adapter, the closure should delegate."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        # Skip __init__; only the method under test is needed.
        plugin = NemoGymPlugin.__new__(NemoGymPlugin)
        stub_trainer = MagicMock()
        stub_trainer._sync_lora_adapter = MagicMock()
        stub_trainer.vllm_generation = MagicMock()

        plugin._setup_lora_sync(stub_trainer)

        # Invoking the installed sync_weights should reach the trainer's own sync.
        stub_trainer.vllm_generation.sync_weights()
        stub_trainer._sync_lora_adapter.assert_called_once()

    def test_check_lora_endpoint_skips_non_main_rank(self):
        """_check_lora_endpoint should not crash when vllm_client is absent (rank 1)."""
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        # spec=[] gives a mock with no attributes, so any attribute access raises.
        attribute_free_gen = MagicMock(spec=[])

        # Passing means the call returns without raising.
        NemoGymPlugin._check_lora_endpoint(attribute_free_gen)
# Allow running this file directly with the stdlib unittest runner.
if __name__ == "__main__":
    unittest.main()