probe vLLM weight-sync routes and select transport per server

The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).

Replace both patches with:

1. A probe of the configured vLLM server's /openapi.json at
   pre_model_load. Three transports are recognized:
     - NCCL (/init_communicator/ + /update_named_param/) — TRL serve
       and axolotl vllm-serve both expose this
     - LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
     - HTTP base64 full-weight (/http_update_weights/) — axolotl
       vllm-serve only

2. A pure-logic ``select_weight_sync_transport`` that picks the
   right one for (server caps × adapter type).

3. ``init_communicator`` is only patched out when the server has no
   NCCL routes; against TRL/axolotl serve modules it stays live so
   full-finetune NCCL sync works.

4. ``post_trainer_create`` uses the selection table to install LoRA
   filesystem sync OR leave the standard NCCL flow alone OR raise
   NotImplementedError (HTTP — pending) OR raise a precise diagnosis
   when no transport is viable. No more silent no-op trainers.
This commit is contained in:
Wing Lian
2026-04-13 18:29:45 +00:00
parent 80a97f192b
commit 69f165b39b
2 changed files with 533 additions and 26 deletions

View File

@@ -19,6 +19,7 @@ Supports two modes:
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Union
from axolotl.integrations.base import BasePlugin
@@ -30,6 +31,107 @@ if TYPE_CHECKING:
LOG = get_logger(__name__)
# ---- vLLM weight-sync transport probe ------------------------------------
@dataclass
class VLLMWeightSyncCapabilities:
"""What weight-sync routes a vLLM server actually exposes.
Discovered once at ``pre_model_load`` time by fetching the server's
``/openapi.json``. Drives the transport-selection table below.
"""
nccl: bool = False # /init_communicator/ + /update_named_param/
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
probed: bool = False
probe_error: str | None = None
routes: list[str] = field(default_factory=list)
@property
def any_full_param_sync(self) -> bool:
"""True if at least one transport can push full-model weights."""
return self.nccl or self.http_full
@property
def any_lora_sync(self) -> bool:
"""True if at least one transport can push LoRA adapters."""
return self.lora_filesystem or self.lora_axolotl or self.nccl
def probe_vllm_weight_sync(
base_url: str, timeout: float = 5.0
) -> VLLMWeightSyncCapabilities:
"""Detect which weight-sync routes the configured vLLM server exposes.
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
we care about is mounted as a POST route there. Falls back to all-False
on any error so the caller can still decide what to do (typically: raise
a clear error rather than silently no-op).
"""
import requests
caps = VLLMWeightSyncCapabilities()
try:
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
r.raise_for_status()
spec = r.json()
routes = sorted((spec.get("paths") or {}).keys())
caps.routes = routes
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
caps.lora_axolotl = "/set_lora_adapter/" in routes
caps.http_full = "/http_update_weights/" in routes
caps.probed = True
except Exception as exc:
caps.probe_error = f"{type(exc).__name__}: {exc}"
LOG.warning(
"NeMo Gym: failed to probe vLLM /openapi.json at %s%s. "
"Will fall back to LoRA-only behavior.",
base_url,
caps.probe_error,
)
return caps
def select_weight_sync_transport(
caps: VLLMWeightSyncCapabilities,
*,
has_lora: bool,
vllm_lora_sync_pref: bool,
) -> str:
"""Pick the right transport for a (server caps, model type) combo.
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
``"none"``. The caller decides what to do with ``"none"`` (typically:
raise an error explaining the misconfiguration).
Selection table:
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
LoRA model + lora endpoint → lora_filesystem
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
Full model + nccl endpoint → nccl
Full model + http endpoint → http_full
anything else → none
"""
if has_lora:
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
return "lora_filesystem"
if caps.lora_filesystem or caps.lora_axolotl:
return "lora_filesystem"
if caps.nccl:
return "nccl"
return "none"
# Full-parameter model
if caps.nccl:
return "nccl"
if caps.http_full:
return "http_full"
return "none"
class NemoGymPlugin(BasePlugin):
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
self._reward_fn = None
self._dataset_lookup = None
self._agent_servers = {}
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
def get_input_args(self):
return "axolotl.integrations.nemo_gym.NemoGymArgs"
def pre_model_load(self, cfg):
"""Apply monkeypatches before trainer creation."""
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
Replaces the previous unconditional ``init_communicator`` monkey-patch
with a probe of the configured vLLM server's ``/openapi.json``. We only
bypass NCCL init when the server we're talking to actually lacks the
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
standard TRL flow alone so full-finetune training can sync weights.
"""
if not cfg.nemo_gym_enabled:
return
# Always skip NCCL communicator init in NeMo Gym mode.
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
trl_cfg = getattr(cfg, "trl", None)
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
return
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
base_url = f"http://{host}:{port}"
self._vllm_caps = probe_vllm_weight_sync(base_url)
if self._vllm_caps.probed:
LOG.info(
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
"lora_axolotl=%s http_full=%s",
base_url,
self._vllm_caps.nccl,
self._vllm_caps.lora_filesystem,
self._vllm_caps.lora_axolotl,
self._vllm_caps.http_full,
)
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
# available we leave VLLMClient.init_communicator alone so the
# standard TRL sync flow can run for full-parameter training.
if not self._vllm_caps.nccl:
self._patch_skip_nccl_init()
def _patch_skip_nccl_init(self):
"""Monkeypatch VLLMClient.init_communicator to no-op.
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
serve script). The NCCL communicator is not needed and fails with both
vLLM V1 engine and standard OpenAI server mode.
Only called when the configured vLLM server doesn't expose
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
TRL's standard ``init_communicator`` would 404 inside trainer
construction; we no-op it so the LoRA filesystem path can install
its own sync in ``post_trainer_create``.
"""
try:
from trl.generation.vllm_client import VLLMClient
VLLMClient._original_init_communicator = VLLMClient.init_communicator
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
"Skipping NCCL init_communicator (LoRA sync mode)"
"Skipping NCCL init_communicator (server has no /init_communicator/)"
)
LOG.info(
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
)
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
except Exception as exc:
LOG.warning(f"Failed to patch VLLMClient: {exc}")
@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
verify_timeout = cfg.nemo_gym_verify_timeout or 30
multi_turn = cfg.nemo_gym_multi_turn or False
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
# - Install LoRA sync (when vllm_lora_sync=True)
# - Or no-op sync_weights (when using standard vLLM server)
# Pick a weight-sync transport based on what the configured vLLM
# server actually exposes (see ``pre_model_load`` probe) and what
# kind of model we're training. The selection table is documented
# in ``select_weight_sync_transport``.
trl_cfg = getattr(cfg, "trl", None)
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
vllm_gen = trainer.vllm_generation
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
adapter = getattr(cfg, "adapter", None)
has_lora = adapter in ("lora", "qlora")
vllm_lora_sync_pref = bool(
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
)
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
transport = select_weight_sync_transport(
caps,
has_lora=has_lora,
vllm_lora_sync_pref=vllm_lora_sync_pref,
)
if transport == "lora_filesystem":
self._setup_lora_sync(trainer)
# Verify the vLLM server supports runtime LoRA loading
self._check_lora_endpoint(vllm_gen)
else:
# No NCCL, no LoRA sync — skip all weight sync paths
vllm_gen.sync_weights = lambda: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
LOG.info("NeMo Gym weight sync: LoRA filesystem")
elif transport == "nccl":
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
# alone (pre_model_load only patched it when the probe found no
# NCCL route) so the trainer's normal weight-sync flow runs.
LOG.info(
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
)
type(vllm_gen).sync_weights = lambda self: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
elif transport == "http_full":
# Full-parameter HTTP sync — implementation lands in step 3.
# For now, fail loudly so users know the path is detected but
# not yet wired up, instead of silently no-oping like before.
raise NotImplementedError(
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
"but the client-side sync helper is not yet implemented "
"(planned). Use `adapter: lora|qlora` for now, or use a "
"vLLM serve module that exposes /init_communicator/ for "
"NCCL sync."
)
# Also patch the async trainer's internal sync method
if hasattr(trainer, "_maybe_sync_vllm_weights"):
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
"Async weight sync skipped (NeMo Gym mode)"
else: # transport == "none"
# No viable sync path. Build a precise error so the user knows
# exactly what's missing and how to fix it.
if not caps.probed:
msg = (
"could not probe the vLLM server's "
f"/openapi.json: {caps.probe_error}. "
"Verify that vLLM is reachable at "
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
)
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
elif has_lora:
msg = (
"the vLLM server has neither NCCL routes "
"(/init_communicator/) nor a LoRA-loading route "
"(/v1/load_lora_adapter or /set_lora_adapter/). "
"Restart vLLM with `--enable-lora --max-lora-rank N "
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
"server, or use `axolotl vllm-serve` for the "
"NCCL-capable serve module."
)
else:
msg = (
"the vLLM server exposes no full-parameter sync route "
"(/init_communicator/ for NCCL or /http_update_weights/ "
"for HTTP). Use `axolotl vllm-serve` (which has both) "
"or set `adapter: lora|qlora`."
)
raise ValueError(
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
"weight sync the trainer's gradient updates never reach the "
"rollout policy (functionally a no-op trainer)."
)
if multi_turn:
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)

View File

@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
assert cfg.dataloader_num_workers == 0
class TestSelectWeightSyncTransport(unittest.TestCase):
"""Pure-logic table tests for ``select_weight_sync_transport``."""
def _caps(self, **kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_native_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_filesystem=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "lora_filesystem"
)
def test_lora_with_axolotl_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_axolotl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "lora_filesystem"
)
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "nccl"
)
def test_full_param_prefers_nccl(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True, http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "nccl"
)
def test_full_param_falls_back_to_http(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "http_full"
)
def test_full_param_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps() # all False
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "none"
)
def test_lora_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps()
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "none"
)
class TestProbeVllmWeightSync(unittest.TestCase):
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
def test_stock_vllm_with_lora_enabled(self):
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/v1/models": {"get": {}},
"/v1/load_lora_adapter": {"post": {}},
"/v1/unload_lora_adapter": {"post": {}},
"/v1/completions": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.lora_filesystem is True
assert caps.lora_axolotl is False
assert caps.nccl is False
assert caps.http_full is False
def test_axolotl_serve_lora_full_capabilities(self):
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/set_lora_adapter/": {"post": {}},
"/clear_lora_adapter/": {"post": {}},
"/http_update_weights/": {"post": {}},
"/v1/load_lora_adapter": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_axolotl is True
assert caps.lora_filesystem is True
assert caps.http_full is True
def test_trl_vllm_serve_nccl_only(self):
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/close_communicator/": {"post": {}},
"/generate/": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_filesystem is False
assert caps.lora_axolotl is False
assert caps.http_full is False
def test_unreachable_server_records_error(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Connection refused")
caps = probe_vllm_weight_sync("http://localhost:9999")
assert caps.probed is False
assert caps.probe_error is not None
assert "ConnectionError" in caps.probe_error
assert caps.nccl is False
assert caps.lora_filesystem is False
class TestPluginWeightSyncEnforcement(unittest.TestCase):
"""End-to-end test of post_trainer_create's transport-selection branch.
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
leaving the trainer learning in isolation while vLLM kept serving the
unmodified base model. After the fix:
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
- LoRA + only NCCL endpoint → uses NCCL broadcast
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
- No usable transport → raises ValueError with a precise diagnosis
"""
@staticmethod
def _fake_cfg(adapter, vllm_lora_sync):
class FakeTRL:
pass
class FakeCfg:
pass
trl = FakeTRL()
trl.vllm_lora_sync = vllm_lora_sync
trl.vllm_server_host = "127.0.0.1"
trl.vllm_server_port = 8000
cfg = FakeCfg()
cfg.nemo_gym_enabled = True
cfg.nemo_gym_model_name = None
cfg.base_model = "test/model"
cfg.nemo_gym_verify_timeout = 30
cfg.nemo_gym_multi_turn = True
cfg.adapter = adapter
cfg.trl = trl
return cfg
@staticmethod
def _fake_trainer():
class FakeVLLMGen:
sync_weights = staticmethod(lambda: None)
class FakeTrainer:
vllm_generation = FakeVLLMGen()
return FakeTrainer()
@staticmethod
def _caps(**kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(lora_filesystem=True)
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with (
patch.object(plugin, "_setup_lora_sync") as setup,
patch.object(plugin, "_check_lora_endpoint") as check,
patch.object(plugin, "_wire_multi_turn") as wire,
):
plugin.post_trainer_create(cfg, trainer)
setup.assert_called_once()
check.assert_called_once()
wire.assert_called_once()
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps() # all False, but probed
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "load_lora_adapter" in msg
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(nccl=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with patch.object(plugin, "_wire_multi_turn") as wire:
plugin.post_trainer_create(cfg, trainer)
wire.assert_called_once()
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(http_full=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(NotImplementedError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "HTTP weight sync" in str(ctx.exception)
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps()
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "init_communicator" in msg
assert "http_update_weights" in msg
def test_unprobed_caps_raises_with_probe_failure_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
# Plugin._vllm_caps left as default-None: the post_trainer_create
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
# probed=False, so the error path should mention probing.
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "could not probe" in str(ctx.exception)
class TestNemoGymE2E(unittest.TestCase):
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.