"""Unit tests for NeMo Gym integration.
|
|
|
|
Tests the core parsing, routing, reward, and plugin wiring logic
|
|
without requiring a running NeMo Gym server or GPU.
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
class TestParseAgentResponse(unittest.TestCase):
|
|
"""Tests for _parse_agent_response in multi_turn.py."""
|
|
|
|
def _parse(self, response, eos_token_id=2):
|
|
from axolotl.integrations.nemo_gym.multi_turn import _parse_agent_response
|
|
|
|
return _parse_agent_response(response, eos_token_id)
|
|
|
|
def test_empty_response_returns_defaults(self):
|
|
result = self._parse({})
|
|
assert result["prompt_ids"] == [2]
|
|
assert result["completion_ids"] == [2]
|
|
assert result["env_mask"] == [0]
|
|
assert result["reward"] == 0.0
|
|
assert result["num_turns"] == 0
|
|
|
|
def test_error_response_returns_defaults(self):
|
|
result = self._parse({"error": "something broke"})
|
|
assert result["reward"] == 0.0
|
|
assert result["num_turns"] == 0
|
|
|
|
def test_single_turn_function_call(self):
|
|
response = {
|
|
"response": {
|
|
"output": [
|
|
{
|
|
"type": "function_call",
|
|
"name": "guess_word",
|
|
"arguments": '{"guess": "crane"}',
|
|
"call_id": "call_1",
|
|
"prompt_token_ids": [10, 20, 30],
|
|
"generation_token_ids": [40, 50],
|
|
"generation_log_probs": [-0.1, -0.2],
|
|
}
|
|
]
|
|
},
|
|
"reward": 0.5,
|
|
}
|
|
result = self._parse(response)
|
|
assert result["prompt_ids"] == [10, 20, 30]
|
|
assert result["completion_ids"] == [40, 50]
|
|
assert result["env_mask"] == [1, 1] # model tokens
|
|
assert result["logprobs"] == [-0.1, -0.2]
|
|
assert result["reward"] == 0.5
|
|
assert result["num_turns"] == 1
|
|
|
|
def test_multi_turn_preserves_env_mask(self):
|
|
"""Second turn's prompt tokens (tool results) get env_mask=0."""
|
|
response = {
|
|
"response": {
|
|
"output": [
|
|
{
|
|
"type": "function_call",
|
|
"prompt_token_ids": [10, 20],
|
|
"generation_token_ids": [30, 31],
|
|
"generation_log_probs": [-0.1, -0.2],
|
|
},
|
|
{
|
|
"type": "function_call_output",
|
|
"output": '{"feedback": "XYGXY"}',
|
|
},
|
|
{
|
|
"type": "function_call",
|
|
# prompt includes original + gen + tool output
|
|
"prompt_token_ids": [10, 20, 30, 31, 100, 101, 102],
|
|
"generation_token_ids": [40, 41],
|
|
"generation_log_probs": [-0.3, -0.4],
|
|
},
|
|
]
|
|
},
|
|
"reward": 0.3,
|
|
}
|
|
result = self._parse(response)
|
|
assert result["prompt_ids"] == [10, 20]
|
|
# completion = gen1 + tool_result + gen2
|
|
assert result["completion_ids"] == [30, 31, 100, 101, 102, 40, 41]
|
|
# env_mask: gen1=model(1), tool=env(0), gen2=model(1)
|
|
assert result["env_mask"] == [1, 1, 0, 0, 0, 1, 1]
|
|
assert result["num_turns"] == 2
|
|
|
|
def test_empty_output_preserves_reward(self):
|
|
response = {
|
|
"response": {"output": []},
|
|
"reward": 0.42,
|
|
}
|
|
result = self._parse(response)
|
|
assert result["reward"] == 0.42
|
|
|
|
def test_message_only_output(self):
|
|
"""A message with text but no function calls."""
|
|
response = {
|
|
"response": {
|
|
"output": [
|
|
{
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"content": [
|
|
{"type": "output_text", "text": "I'll guess crane."}
|
|
],
|
|
"prompt_token_ids": [10, 20],
|
|
"generation_token_ids": [30, 31, 32],
|
|
"generation_log_probs": [-0.1, -0.2, -0.3],
|
|
}
|
|
]
|
|
},
|
|
"reward": 0.1,
|
|
}
|
|
result = self._parse(response)
|
|
assert result["num_turns"] == 1
|
|
assert result["completion_ids"] == [30, 31, 32]
|
|
assert result["env_mask"] == [1, 1, 1]
|
|
|
|
|
|
class TestRewardEnv(unittest.TestCase):
|
|
"""Tests for reward_env passthrough function."""
|
|
|
|
def test_with_list_rewards(self):
|
|
from axolotl.integrations.nemo_gym.rewards import reward_env
|
|
|
|
result = reward_env([["comp1"], ["comp2"]], env_reward=[0.5, 0.8])
|
|
assert result == [0.5, 0.8]
|
|
|
|
def test_with_scalar_reward(self):
|
|
from axolotl.integrations.nemo_gym.rewards import reward_env
|
|
|
|
result = reward_env([["comp1"], ["comp2"]], env_reward=0.7)
|
|
assert result == [0.7, 0.7]
|
|
|
|
def test_missing_reward_returns_zeros(self):
|
|
from axolotl.integrations.nemo_gym.rewards import reward_env
|
|
|
|
result = reward_env([["comp1"], ["comp2"]])
|
|
assert result == [0.0, 0.0]
|
|
|
|
|
|
class TestRewardNemoGymVerify(unittest.TestCase):
|
|
"""Tests for reward_nemo_gym_verify with mocked HTTP."""
|
|
|
|
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
|
|
@patch("axolotl.integrations.nemo_gym.rewards.requests")
|
|
def test_calls_verify_endpoint(self, mock_requests, mock_get_urls):
|
|
from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify
|
|
|
|
mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
|
|
mock_resp = MagicMock()
|
|
mock_resp.ok = True
|
|
mock_resp.json.return_value = {"reward": 0.75}
|
|
mock_requests.post.return_value = mock_resp
|
|
|
|
result = reward_nemo_gym_verify(
|
|
completions=[[{"role": "assistant", "content": "crane"}]],
|
|
prompts=[[{"role": "user", "content": "Guess a word"}]],
|
|
resources_server_ref=[{"name": "wordle"}],
|
|
verify_extra=[{}],
|
|
)
|
|
|
|
assert result == [0.75]
|
|
mock_requests.post.assert_called_once()
|
|
|
|
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
|
|
def test_missing_server_returns_zero(self, mock_get_urls):
|
|
from axolotl.integrations.nemo_gym.rewards import reward_nemo_gym_verify
|
|
|
|
mock_get_urls.return_value = {}
|
|
|
|
result = reward_nemo_gym_verify(
|
|
completions=[[{"role": "assistant", "content": "crane"}]],
|
|
prompts=[[{"role": "user", "content": "Guess"}]],
|
|
resources_server_ref=[{"name": "unknown_server"}],
|
|
verify_extra=[{}],
|
|
)
|
|
assert result == [0.0]
|
|
|
|
|
|
class TestNormalizeHost(unittest.TestCase):
|
|
"""Tests for server.py _normalize_host helper."""
|
|
|
|
def test_zero_addr_normalized(self):
|
|
from axolotl.integrations.nemo_gym.server import _normalize_host
|
|
|
|
assert _normalize_host("0.0.0.0") == "127.0.0.1"
|
|
|
|
def test_localhost_normalized(self):
|
|
from axolotl.integrations.nemo_gym.server import _normalize_host
|
|
|
|
assert _normalize_host("localhost") == "127.0.0.1"
|
|
|
|
def test_loopback_passthrough(self):
|
|
from axolotl.integrations.nemo_gym.server import _normalize_host
|
|
|
|
assert _normalize_host("127.0.0.1") == "127.0.0.1"
|
|
|
|
def test_custom_fallback(self):
|
|
from axolotl.integrations.nemo_gym.server import _normalize_host
|
|
|
|
assert _normalize_host("0.0.0.0", fallback="10.0.0.1") == "10.0.0.1"
|
|
|
|
def test_real_ip_passthrough(self):
|
|
from axolotl.integrations.nemo_gym.server import _normalize_host
|
|
|
|
assert _normalize_host("192.168.1.50") == "192.168.1.50"
|
|
|
|
|
|
class TestDatasetLookupKeying(unittest.TestCase):
|
|
"""Verify dataset lookup uses last message content as key."""
|
|
|
|
def test_single_message_prompt(self):
|
|
"""Single-message prompt: [0] == [-1], both work."""
|
|
prompt = [{"role": "user", "content": "Play Wordle!"}]
|
|
assert prompt[0]["content"] == prompt[-1]["content"]
|
|
|
|
def test_multi_message_prompt_uses_last(self):
|
|
"""Multi-message prompt: must use [-1] to match data_producer lookup."""
|
|
prompt = [
|
|
{"role": "system", "content": "You are a game player."},
|
|
{"role": "user", "content": "Play Wordle!"},
|
|
]
|
|
# data_producer.py line 92 uses prompt[-1]
|
|
key = prompt[-1]["content"]
|
|
assert key == "Play Wordle!"
|
|
# Old code used prompt[0] which would be wrong here
|
|
assert prompt[0]["content"] != key
|
|
|
|
|
|
class TestAgentRefPreservation(unittest.TestCase):
|
|
"""Verify agent_ref is preserved through the dispatch chain."""
|
|
|
|
def test_data_producer_preserves_agent_ref(self):
|
|
"""Simulates the data_producer lookup logic."""
|
|
# Simulate what plugin.py builds
|
|
dataset_lookup = {
|
|
"Play Wordle!": {
|
|
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
|
"agent_ref": {"name": "wordle_simple_agent"},
|
|
"verify_extra": {
|
|
"responses_create_params": {
|
|
"input": [{"role": "user", "content": "Play Wordle!"}]
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
# Simulate data_producer.py logic (after fix)
|
|
prompt_text = "Play Wordle!"
|
|
full_item = dataset_lookup.get(prompt_text, {})
|
|
item = full_item.get("verify_extra", {})
|
|
if "agent_ref" in full_item and "agent_ref" not in item:
|
|
item["agent_ref"] = full_item["agent_ref"]
|
|
|
|
assert "agent_ref" in item
|
|
assert item["agent_ref"]["name"] == "wordle_simple_agent"
|
|
|
|
def test_multi_turn_preserves_agent_ref(self):
|
|
"""Simulates the multi_turn.py dispatch logic."""
|
|
dataset_lookup = {
|
|
"Play Wordle!": {
|
|
"agent_ref": {"name": "wordle_simple_agent"},
|
|
"verify_extra": {
|
|
"responses_create_params": {
|
|
"input": [{"role": "user", "content": "Play Wordle!"}]
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
# Simulate multi_turn.py logic (after fix)
|
|
prompt_str = "Play Wordle!"
|
|
full_item = None
|
|
for key, val in dataset_lookup.items():
|
|
if isinstance(key, str) and prompt_str == key:
|
|
full_item = val
|
|
break
|
|
|
|
dispatched = full_item.get("verify_extra", full_item)
|
|
if isinstance(dispatched, dict) and "agent_ref" not in dispatched:
|
|
agent_ref = full_item.get("agent_ref")
|
|
if agent_ref:
|
|
dispatched = {**dispatched, "agent_ref": agent_ref}
|
|
|
|
assert "agent_ref" in dispatched
|
|
assert dispatched["agent_ref"]["name"] == "wordle_simple_agent"
|
|
|
|
|
|
class TestCallAgentsRouting(unittest.TestCase):
|
|
"""Tests for _call_agents routing via agent_ref."""
|
|
|
|
def test_routes_to_correct_agent(self):
|
|
"""Items with agent_ref should route to the matching agent server."""
|
|
|
|
agent_servers = {
|
|
"wordle_agent": "http://localhost:11111",
|
|
"math_agent": "http://localhost:22222",
|
|
}
|
|
|
|
items = [
|
|
{
|
|
"agent_ref": {"name": "wordle_agent"},
|
|
"responses_create_params": {
|
|
"input": [{"role": "user", "content": "Play"}]
|
|
},
|
|
}
|
|
]
|
|
|
|
# We can't actually call the agent, but verify the URL resolution
|
|
# by checking _call_agents builds the right request
|
|
# The function uses aiohttp — just verify agent_ref lookup works
|
|
item = items[0]
|
|
agent_ref = item.get("agent_ref", {})
|
|
agent_name = agent_ref.get("name", "")
|
|
agent_url = agent_servers.get(agent_name, "")
|
|
assert agent_url == "http://localhost:11111"
|
|
|
|
def test_fallback_to_first_agent(self):
|
|
"""Items without agent_ref should use first available agent."""
|
|
agent_servers = {"default_agent": "http://localhost:33333"}
|
|
item = {
|
|
"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}
|
|
}
|
|
agent_ref = item.get("agent_ref", {})
|
|
agent_name = agent_ref.get("name", "")
|
|
agent_url = agent_servers.get(agent_name, "")
|
|
if not agent_url and agent_servers:
|
|
agent_url = next(iter(agent_servers.values()))
|
|
assert agent_url == "http://localhost:33333"
|
|
|
|
|
|
class TestPluginDefaults(unittest.TestCase):
|
|
"""Tests for plugin config enforcement."""
|
|
|
|
def test_dataloader_num_workers_forced_to_zero(self):
|
|
"""Plugin should set dataloader_num_workers=0 for NeMo Gym."""
|
|
|
|
# Simulate the plugin logic
|
|
class FakeCfg:
|
|
dataloader_num_workers = 4
|
|
nemo_gym_multi_turn = True
|
|
|
|
cfg = FakeCfg()
|
|
# Replicate plugin.get_training_args logic
|
|
if getattr(cfg, "dataloader_num_workers", None) not in (None, 0):
|
|
pass # would log warning
|
|
cfg.dataloader_num_workers = 0
|
|
assert cfg.dataloader_num_workers == 0
|
|
|
|
def test_dataloader_num_workers_none_stays_zero(self):
|
|
class FakeCfg:
|
|
dataloader_num_workers = None
|
|
|
|
cfg = FakeCfg()
|
|
cfg.dataloader_num_workers = 0
|
|
assert cfg.dataloader_num_workers == 0
|
|
|
|
|
|
class TestNemoGymE2E(unittest.TestCase):
|
|
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
|
|
|
Exercises the full NemoGymDataProducer.produce() pipeline with mocked HTTP
|
|
responses, verifying that multi-turn Wordle agent responses are correctly
|
|
parsed into padded tensors with proper env_mask, logprobs, and rewards.
|
|
No GPU or NeMo Gym server required.
|
|
"""
|
|
|
|
# A realistic 2-turn agent /run response (guess + feedback + guess + done)
|
|
AGENT_RESPONSE = {
|
|
"response": {
|
|
"output": [
|
|
{
|
|
"type": "function_call",
|
|
"name": "guess_word",
|
|
"arguments": '{"guess": "crane"}',
|
|
"call_id": "call_1",
|
|
"id": "call_1",
|
|
"status": "completed",
|
|
"prompt_token_ids": [1, 2, 3, 4, 5],
|
|
"generation_token_ids": [10, 11, 12, 13],
|
|
"generation_log_probs": [-0.1, -0.2, -0.3, -0.4],
|
|
},
|
|
{
|
|
"type": "function_call_output",
|
|
"call_id": "call_1",
|
|
"output": '{"feedback":"XYGXY","guesses_remaining":5,"done":false}',
|
|
},
|
|
{
|
|
"type": "function_call",
|
|
"name": "guess_word",
|
|
"arguments": '{"guess": "slide"}',
|
|
"call_id": "call_2",
|
|
"id": "call_2",
|
|
"status": "completed",
|
|
# prompt = original(5) + gen1(4) + tool_output(3 tokens)
|
|
"prompt_token_ids": [1, 2, 3, 4, 5, 10, 11, 12, 13, 50, 51, 52],
|
|
"generation_token_ids": [20, 21, 22],
|
|
"generation_log_probs": [-0.5, -0.6, -0.7],
|
|
},
|
|
],
|
|
},
|
|
"reward": 0.42,
|
|
}
|
|
|
|
def _make_mock_trainer(self):
|
|
"""Create a minimal mock trainer with the attributes produce() needs."""
|
|
trainer = MagicMock()
|
|
trainer.accelerator.is_main_process = True
|
|
trainer.accelerator.device = "cpu"
|
|
trainer.max_completion_length = 512
|
|
trainer.temperature = 0.8
|
|
trainer.pad_token_id = 0
|
|
trainer.processing_class.eos_token_id = 2
|
|
trainer.processing_class.batch_decode.return_value = ["crane slide"]
|
|
return trainer
|
|
|
|
@patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
|
|
def test_produce_returns_valid_rollout_dataset(self, mock_call_agents):
|
|
"""Full pipeline: produce() → _call_agents (mocked) → parse → RolloutDataset."""
|
|
|
|
from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer
|
|
|
|
# Mock _call_agents — it's async, so return a coroutine
|
|
async def fake_call_agents(**kwargs):
|
|
return [self.AGENT_RESPONSE, self.AGENT_RESPONSE]
|
|
|
|
mock_call_agents.side_effect = fake_call_agents
|
|
|
|
# Build a minimal mock of GRPODataProducer's __init__ dependencies
|
|
# We can't easily call super().__init__, so we'll set attributes directly
|
|
producer = NemoGymDataProducer.__new__(NemoGymDataProducer)
|
|
producer._agent_servers = {"wordle_agent": "http://mock:9999"}
|
|
producer._dataset_lookup = {
|
|
"Play Wordle!": {
|
|
"agent_ref": {"name": "wordle_agent"},
|
|
"verify_extra": {
|
|
"responses_create_params": {
|
|
"input": [{"role": "user", "content": "Play Wordle!"}],
|
|
}
|
|
},
|
|
}
|
|
}
|
|
producer._request_timeout = 30
|
|
producer._num_generations = 2
|
|
|
|
# Mock the trainer
|
|
trainer = self._make_mock_trainer()
|
|
producer._trainer = trainer
|
|
|
|
# Mock the prompt iterator (returns a batch of 1 input)
|
|
producer._prompt_iter = iter(
|
|
[
|
|
[
|
|
{
|
|
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
|
}
|
|
]
|
|
]
|
|
)
|
|
producer._prompt_dl = [
|
|
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
|
]
|
|
|
|
# Call produce
|
|
result = producer.produce(model=MagicMock(), global_step=1)
|
|
|
|
# Verify result structure
|
|
assert result is not None
|
|
data = result._data
|
|
|
|
# Check tensor shapes — 2 rollouts (num_generations=2)
|
|
assert data["prompt_ids"].shape[0] == 2
|
|
assert data["completion_ids"].shape[0] == 2
|
|
assert data["completion_mask"].shape[0] == 2
|
|
assert data["sampling_per_token_logps"].shape[0] == 2
|
|
assert data["tool_mask"].shape[0] == 2
|
|
|
|
# Verify completion content — each rollout should have:
|
|
# gen1(4) + tool_output(3) + gen2(3) = 10 tokens
|
|
# (padded to same length across the batch, but both are same here)
|
|
comp_len = data["completion_mask"][0].sum().item()
|
|
assert comp_len == 10, f"Expected 10 completion tokens, got {comp_len}"
|
|
|
|
# Verify env_mask: gen1=1,1,1,1 tool=0,0,0 gen2=1,1,1
|
|
tool_mask = data["tool_mask"][0][:comp_len].tolist()
|
|
assert tool_mask == [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]
|
|
|
|
# Verify logprobs are populated (use approx for float32 precision)
|
|
import pytest
|
|
|
|
logps = data["sampling_per_token_logps"][0][:comp_len].tolist()
|
|
assert logps[:4] == pytest.approx([-0.1, -0.2, -0.3, -0.4], abs=1e-6)
|
|
assert logps[4:7] == pytest.approx([0.0, 0.0, 0.0], abs=1e-6)
|
|
assert logps[7:10] == pytest.approx([-0.5, -0.6, -0.7], abs=1e-6)
|
|
|
|
# Verify rewards were injected into inputs
|
|
assert data["_deferred_inputs"][0]["env_reward"] == 0.42
|
|
assert data["_deferred_inputs"][1]["env_reward"] == 0.42
|
|
|
|
# Verify deferred scoring markers
|
|
assert data["_pending_policy_logps"] is True
|
|
|
|
@patch("axolotl.integrations.nemo_gym.data_producer._call_agents")
|
|
def test_produce_handles_failed_agent_response(self, mock_call_agents):
|
|
"""Failed agent responses should produce default (length-1) rollouts."""
|
|
|
|
from axolotl.integrations.nemo_gym.data_producer import NemoGymDataProducer
|
|
|
|
# One success, one failure — async mock
|
|
async def fake_call_agents(**kwargs):
|
|
return [
|
|
self.AGENT_RESPONSE,
|
|
{
|
|
"error": "Connection timeout",
|
|
"response": {"output": []},
|
|
"reward": 0.0,
|
|
},
|
|
]
|
|
|
|
mock_call_agents.side_effect = fake_call_agents
|
|
|
|
producer = NemoGymDataProducer.__new__(NemoGymDataProducer)
|
|
producer._agent_servers = {"wordle_agent": "http://mock:9999"}
|
|
producer._dataset_lookup = {}
|
|
producer._request_timeout = 30
|
|
producer._num_generations = 2
|
|
producer._trainer = self._make_mock_trainer()
|
|
producer._prompt_iter = iter(
|
|
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
|
)
|
|
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
|
|
|
result = producer.produce(model=MagicMock(), global_step=1)
|
|
|
|
assert result is not None
|
|
data = result._data
|
|
|
|
# Both rollouts present
|
|
assert data["completion_ids"].shape[0] == 2
|
|
|
|
# First rollout has real tokens, second has just eos (length 1)
|
|
mask0 = data["completion_mask"][0].sum().item()
|
|
mask1 = data["completion_mask"][1].sum().item()
|
|
assert mask0 == 10 # full response
|
|
assert mask1 == 1 # default fallback (just eos)
|
|
|
|
# Rewards: success=0.42, failure=0.0
|
|
assert data["_deferred_inputs"][0]["env_reward"] == 0.42
|
|
assert data["_deferred_inputs"][1]["env_reward"] == 0.0
|
|
|
|
@patch("axolotl.integrations.nemo_gym.rewards._get_verify_urls")
|
|
@patch("axolotl.integrations.nemo_gym.rewards.requests")
|
|
def test_reward_functions_chain(self, mock_requests, mock_get_urls):
|
|
"""Test that reward_env and reward_nemo_gym_verify can be used together."""
|
|
from axolotl.integrations.nemo_gym.rewards import (
|
|
reward_env,
|
|
reward_nemo_gym_verify,
|
|
)
|
|
|
|
completions = [[{"role": "assistant", "content": "crane"}]]
|
|
prompts = [[{"role": "user", "content": "Guess"}]]
|
|
|
|
# reward_env: passthrough from agent
|
|
env_result = reward_env(completions, prompts, env_reward=[0.42])
|
|
assert env_result == [0.42]
|
|
|
|
# reward_nemo_gym_verify: calls /verify
|
|
mock_get_urls.return_value = {"wordle": "http://localhost:9999/verify"}
|
|
mock_resp = MagicMock()
|
|
mock_resp.ok = True
|
|
mock_resp.json.return_value = {"reward": 0.75}
|
|
mock_requests.post.return_value = mock_resp
|
|
|
|
verify_result = reward_nemo_gym_verify(
|
|
completions,
|
|
prompts,
|
|
resources_server_ref=[{"name": "wordle"}],
|
|
verify_extra=[{}],
|
|
)
|
|
assert verify_result == [0.75]
|
|
|
|
# Both rewards can coexist (as they would in a multi-reward config)
|
|
combined = [e + v for e, v in zip(env_result, verify_result, strict=True)]
|
|
assert combined == [1.17]
|
|
|
|
|
|
class TestLoRASyncSetup(unittest.TestCase):
|
|
"""Tests for _setup_lora_sync delegation logic."""
|
|
|
|
def test_delegates_to_async_trainer(self):
|
|
"""When trainer has _sync_lora_adapter, the closure should delegate."""
|
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
|
|
plugin = NemoGymPlugin.__new__(NemoGymPlugin)
|
|
|
|
trainer = MagicMock()
|
|
trainer._sync_lora_adapter = MagicMock()
|
|
trainer.vllm_generation = MagicMock()
|
|
|
|
plugin._setup_lora_sync(trainer)
|
|
|
|
# The closure should be installed
|
|
trainer.vllm_generation.sync_weights()
|
|
trainer._sync_lora_adapter.assert_called_once()
|
|
|
|
def test_check_lora_endpoint_skips_non_main_rank(self):
|
|
"""_check_lora_endpoint should not crash when vllm_client is absent (rank 1)."""
|
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
|
|
vllm_gen = MagicMock(spec=[]) # No attributes at all
|
|
# Should not raise
|
|
NemoGymPlugin._check_lora_endpoint(vllm_gen)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|