probe vLLM weight-sync routes and select transport per server
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).
Replace both patches with:
1. A probe of the configured vLLM server's /openapi.json at
pre_model_load. Three transports are recognized:
- NCCL (/init_communicator/ + /update_named_param/) — TRL serve
and axolotl vllm-serve both expose this
- LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
- HTTP base64 full-weight (/http_update_weights/) — axolotl
vllm-serve only
2. A pure-logic ``select_weight_sync_transport`` that picks the
right one for (server caps × adapter type).
3. ``init_communicator`` is only patched out when the server has no
NCCL routes; against TRL/axolotl serve modules it stays live so
full-finetune NCCL sync works.
4. ``post_trainer_create`` uses the selection table to install LoRA
filesystem sync OR leave the standard NCCL flow alone OR raise
NotImplementedError (HTTP — pending) OR raise a precise diagnosis
when no transport is viable. No more silent no-op trainers.
This commit is contained in:
@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
|
||||
assert cfg.dataloader_num_workers == 0
|
||||
|
||||
|
||||
class TestSelectWeightSyncTransport(unittest.TestCase):
|
||||
"""Pure-logic table tests for ``select_weight_sync_transport``."""
|
||||
|
||||
def _caps(self, **kwargs):
|
||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||
|
||||
c = VLLMWeightSyncCapabilities(probed=True)
|
||||
for k, v in kwargs.items():
|
||||
setattr(c, k, v)
|
||||
return c
|
||||
|
||||
def test_lora_with_native_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(lora_filesystem=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||
== "lora_filesystem"
|
||||
)
|
||||
|
||||
def test_lora_with_axolotl_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(lora_axolotl=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||
== "lora_filesystem"
|
||||
)
|
||||
|
||||
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(nccl=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||
== "nccl"
|
||||
)
|
||||
|
||||
def test_full_param_prefers_nccl(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(nccl=True, http_full=True)
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "nccl"
|
||||
)
|
||||
|
||||
def test_full_param_falls_back_to_http(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(http_full=True)
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "http_full"
|
||||
)
|
||||
|
||||
def test_full_param_no_routes_returns_none(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps() # all False
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "none"
|
||||
)
|
||||
|
||||
def test_lora_no_routes_returns_none(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps()
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||
== "none"
|
||||
)
|
||||
|
||||
|
||||
class TestProbeVllmWeightSync(unittest.TestCase):
|
||||
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
|
||||
|
||||
def test_stock_vllm_with_lora_enabled(self):
|
||||
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/v1/models": {"get": {}},
|
||||
"/v1/load_lora_adapter": {"post": {}},
|
||||
"/v1/unload_lora_adapter": {"post": {}},
|
||||
"/v1/completions": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.lora_filesystem is True
|
||||
assert caps.lora_axolotl is False
|
||||
assert caps.nccl is False
|
||||
assert caps.http_full is False
|
||||
|
||||
def test_axolotl_serve_lora_full_capabilities(self):
|
||||
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/init_communicator/": {"post": {}},
|
||||
"/update_named_param/": {"post": {}},
|
||||
"/batch_update_named_params/": {"post": {}},
|
||||
"/set_lora_adapter/": {"post": {}},
|
||||
"/clear_lora_adapter/": {"post": {}},
|
||||
"/http_update_weights/": {"post": {}},
|
||||
"/v1/load_lora_adapter": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.nccl is True
|
||||
assert caps.lora_axolotl is True
|
||||
assert caps.lora_filesystem is True
|
||||
assert caps.http_full is True
|
||||
|
||||
def test_trl_vllm_serve_nccl_only(self):
|
||||
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/init_communicator/": {"post": {}},
|
||||
"/update_named_param/": {"post": {}},
|
||||
"/batch_update_named_params/": {"post": {}},
|
||||
"/close_communicator/": {"post": {}},
|
||||
"/generate/": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.nccl is True
|
||||
assert caps.lora_filesystem is False
|
||||
assert caps.lora_axolotl is False
|
||||
assert caps.http_full is False
|
||||
|
||||
def test_unreachable_server_records_error(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.side_effect = ConnectionError("Connection refused")
|
||||
caps = probe_vllm_weight_sync("http://localhost:9999")
|
||||
|
||||
assert caps.probed is False
|
||||
assert caps.probe_error is not None
|
||||
assert "ConnectionError" in caps.probe_error
|
||||
assert caps.nccl is False
|
||||
assert caps.lora_filesystem is False
|
||||
|
||||
|
||||
class TestPluginWeightSyncEnforcement(unittest.TestCase):
|
||||
"""End-to-end test of post_trainer_create's transport-selection branch.
|
||||
|
||||
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
|
||||
leaving the trainer learning in isolation while vLLM kept serving the
|
||||
unmodified base model. After the fix:
|
||||
|
||||
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
|
||||
- LoRA + only NCCL endpoint → uses NCCL broadcast
|
||||
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
|
||||
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
|
||||
- No usable transport → raises ValueError with a precise diagnosis
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _fake_cfg(adapter, vllm_lora_sync):
|
||||
class FakeTRL:
|
||||
pass
|
||||
|
||||
class FakeCfg:
|
||||
pass
|
||||
|
||||
trl = FakeTRL()
|
||||
trl.vllm_lora_sync = vllm_lora_sync
|
||||
trl.vllm_server_host = "127.0.0.1"
|
||||
trl.vllm_server_port = 8000
|
||||
|
||||
cfg = FakeCfg()
|
||||
cfg.nemo_gym_enabled = True
|
||||
cfg.nemo_gym_model_name = None
|
||||
cfg.base_model = "test/model"
|
||||
cfg.nemo_gym_verify_timeout = 30
|
||||
cfg.nemo_gym_multi_turn = True
|
||||
cfg.adapter = adapter
|
||||
cfg.trl = trl
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def _fake_trainer():
|
||||
class FakeVLLMGen:
|
||||
sync_weights = staticmethod(lambda: None)
|
||||
|
||||
class FakeTrainer:
|
||||
vllm_generation = FakeVLLMGen()
|
||||
|
||||
return FakeTrainer()
|
||||
|
||||
@staticmethod
|
||||
def _caps(**kwargs):
|
||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||
|
||||
c = VLLMWeightSyncCapabilities(probed=True)
|
||||
for k, v in kwargs.items():
|
||||
setattr(c, k, v)
|
||||
return c
|
||||
|
||||
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(lora_filesystem=True)
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with (
|
||||
patch.object(plugin, "_setup_lora_sync") as setup,
|
||||
patch.object(plugin, "_check_lora_endpoint") as check,
|
||||
patch.object(plugin, "_wire_multi_turn") as wire,
|
||||
):
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
setup.assert_called_once()
|
||||
check.assert_called_once()
|
||||
wire.assert_called_once()
|
||||
|
||||
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps() # all False, but probed
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
msg = str(ctx.exception)
|
||||
assert "no-op trainer" in msg
|
||||
assert "load_lora_adapter" in msg
|
||||
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
|
||||
|
||||
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(nccl=True)
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with patch.object(plugin, "_wire_multi_turn") as wire:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
wire.assert_called_once()
|
||||
|
||||
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(http_full=True)
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(NotImplementedError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
assert "HTTP weight sync" in str(ctx.exception)
|
||||
|
||||
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps()
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
msg = str(ctx.exception)
|
||||
assert "no-op trainer" in msg
|
||||
assert "init_communicator" in msg
|
||||
assert "http_update_weights" in msg
|
||||
|
||||
def test_unprobed_caps_raises_with_probe_failure_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
# Plugin._vllm_caps left as default-None: the post_trainer_create
|
||||
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
|
||||
# probed=False, so the error path should mention probing.
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
assert "could not probe" in str(ctx.exception)
|
||||
|
||||
|
||||
class TestNemoGymE2E(unittest.TestCase):
|
||||
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user