fix async prefetch with nemogym (#3606)

This commit is contained in:
Wing Lian
2026-04-22 09:05:46 -04:00
committed by GitHub
parent 05113bc91a
commit 7420fd4de6
16 changed files with 2388 additions and 135 deletions

View File

@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
assert cfg.dataloader_num_workers == 0
class TestSelectWeightSyncTransport(unittest.TestCase):
"""Pure-logic table tests for ``select_weight_sync_transport``."""
def _caps(self, **kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_native_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_filesystem=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "lora_filesystem"
)
def test_lora_with_axolotl_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_axolotl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "lora_filesystem"
)
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "nccl"
)
def test_full_param_prefers_nccl(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True, http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "nccl"
)
def test_full_param_falls_back_to_http(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "http_full"
)
def test_full_param_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps() # all False
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "none"
)
def test_lora_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps()
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "none"
)
class TestProbeVllmWeightSync(unittest.TestCase):
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
def test_stock_vllm_with_lora_enabled(self):
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/v1/models": {"get": {}},
"/v1/load_lora_adapter": {"post": {}},
"/v1/unload_lora_adapter": {"post": {}},
"/v1/completions": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.lora_filesystem is True
assert caps.lora_axolotl is False
assert caps.nccl is False
assert caps.http_full is False
def test_axolotl_serve_lora_full_capabilities(self):
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/set_lora_adapter/": {"post": {}},
"/clear_lora_adapter/": {"post": {}},
"/http_update_weights/": {"post": {}},
"/v1/load_lora_adapter": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_axolotl is True
assert caps.lora_filesystem is True
assert caps.http_full is True
def test_trl_vllm_serve_nccl_only(self):
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/close_communicator/": {"post": {}},
"/generate/": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_filesystem is False
assert caps.lora_axolotl is False
assert caps.http_full is False
def test_unreachable_server_records_error(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Connection refused")
caps = probe_vllm_weight_sync("http://localhost:9999")
assert caps.probed is False
assert caps.probe_error is not None
assert "ConnectionError" in caps.probe_error
assert caps.nccl is False
assert caps.lora_filesystem is False
class TestPluginWeightSyncEnforcement(unittest.TestCase):
"""End-to-end test of post_trainer_create's transport-selection branch.
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
leaving the trainer learning in isolation while vLLM kept serving the
unmodified base model. After the fix:
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
- LoRA + only NCCL endpoint → uses NCCL broadcast
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
- No usable transport → raises ValueError with a precise diagnosis
"""
@staticmethod
def _fake_cfg(adapter, vllm_lora_sync):
class FakeTRL:
pass
class FakeCfg:
pass
trl = FakeTRL()
trl.vllm_lora_sync = vllm_lora_sync
trl.vllm_server_host = "127.0.0.1"
trl.vllm_server_port = 8000
cfg = FakeCfg()
cfg.nemo_gym_enabled = True
cfg.nemo_gym_model_name = None
cfg.base_model = "test/model"
cfg.nemo_gym_verify_timeout = 30
cfg.nemo_gym_multi_turn = True
cfg.adapter = adapter
cfg.trl = trl
return cfg
@staticmethod
def _fake_trainer():
class FakeVLLMGen:
sync_weights = staticmethod(lambda: None)
class FakeTrainer:
vllm_generation = FakeVLLMGen()
return FakeTrainer()
@staticmethod
def _caps(**kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(lora_filesystem=True)
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with (
patch.object(plugin, "_setup_lora_sync") as setup,
patch.object(plugin, "_check_lora_endpoint") as check,
patch.object(plugin, "_wire_multi_turn") as wire,
):
plugin.post_trainer_create(cfg, trainer)
setup.assert_called_once()
check.assert_called_once()
wire.assert_called_once()
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps() # all False, but probed
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "load_lora_adapter" in msg
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(nccl=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with patch.object(plugin, "_wire_multi_turn") as wire:
plugin.post_trainer_create(cfg, trainer)
wire.assert_called_once()
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(http_full=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(NotImplementedError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "HTTP weight sync" in str(ctx.exception)
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps()
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "init_communicator" in msg
assert "http_update_weights" in msg
def test_unprobed_caps_raises_with_probe_failure_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
# Plugin._vllm_caps left as default-None: the post_trainer_create
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
# probed=False, so the error path should mention probing.
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "could not probe" in str(ctx.exception)
class TestNemoGymE2E(unittest.TestCase):
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
@@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
trainer = self._make_mock_trainer()
producer._trainer = trainer
# Mock the prompt iterator (returns a batch of 1 input)
producer._prompt_iter = iter(
[
[
{
"prompt": [{"role": "user", "content": "Play Wordle!"}],
}
]
]
)
producer._prompt_dl = [
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
# copies of each unique prompt — one entry per rollout.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
# Call produce
result = producer.produce(model=MagicMock(), global_step=1)
@@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
producer._request_timeout = 30
producer._num_generations = 2
producer._trainer = self._make_mock_trainer()
producer._prompt_iter = iter(
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
)
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
# RepeatSampler pre-expands by num_generations=2.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play!"}]},
{"prompt": [{"role": "user", "content": "Play!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
result = producer.produce(model=MagicMock(), global_step=1)