From 69f165b39bc11e719b9b1c1ea9753f0e374392a2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 13 Apr 2026 18:29:45 +0000
Subject: [PATCH] probe vLLM weight-sync routes and select transport per server
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op sync_weights
when vllm_lora_sync was off. Combined, this turned the trainer into a
functional no-op whenever (a) the user ran NeMo Gym + LoRA without
remembering to set vllm_lora_sync=true or (b) the user ran NeMo Gym +
full fine-tune (which had no working sync path under the old code).

Replace both patches with:

1. A probe of the configured vLLM server's /openapi.json at
   pre_model_load. Three transports are recognized:
   - NCCL (/init_communicator/ + /update_named_param/) — TRL serve and
     axolotl vllm-serve both expose this
   - LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
   - HTTP base64 full-weight (/http_update_weights/) — axolotl
     vllm-serve only
2. A pure-logic ``select_weight_sync_transport`` that picks the right
   one for (server caps × adapter type).
3. ``init_communicator`` is only patched out when the server has no
   NCCL routes; against TRL/axolotl serve modules it stays live so
   full-finetune NCCL sync works.
4. ``post_trainer_create`` uses the selection table to install LoRA
   filesystem sync OR leave the standard NCCL flow alone OR raise
   NotImplementedError (HTTP — pending) OR raise a precise diagnosis
   when no transport is viable. No more silent no-op trainers.
--- src/axolotl/integrations/nemo_gym/plugin.py | 236 ++++++++++++-- tests/integrations/test_nemo_gym.py | 323 ++++++++++++++++++++ 2 files changed, 533 insertions(+), 26 deletions(-) diff --git a/src/axolotl/integrations/nemo_gym/plugin.py b/src/axolotl/integrations/nemo_gym/plugin.py index 14de684cf..b85e344db 100644 --- a/src/axolotl/integrations/nemo_gym/plugin.py +++ b/src/axolotl/integrations/nemo_gym/plugin.py @@ -19,6 +19,7 @@ Supports two modes: from __future__ import annotations import os +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Union from axolotl.integrations.base import BasePlugin @@ -30,6 +31,107 @@ if TYPE_CHECKING: LOG = get_logger(__name__) +# ---- vLLM weight-sync transport probe ------------------------------------ + + +@dataclass +class VLLMWeightSyncCapabilities: + """What weight-sync routes a vLLM server actually exposes. + + Discovered once at ``pre_model_load`` time by fetching the server's + ``/openapi.json``. Drives the transport-selection table below. + """ + + nccl: bool = False # /init_communicator/ + /update_named_param/ + lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native) + lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension) + http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension) + probed: bool = False + probe_error: str | None = None + routes: list[str] = field(default_factory=list) + + @property + def any_full_param_sync(self) -> bool: + """True if at least one transport can push full-model weights.""" + return self.nccl or self.http_full + + @property + def any_lora_sync(self) -> bool: + """True if at least one transport can push LoRA adapters.""" + return self.lora_filesystem or self.lora_axolotl or self.nccl + + +def probe_vllm_weight_sync( + base_url: str, timeout: float = 5.0 +) -> VLLMWeightSyncCapabilities: + """Detect which weight-sync routes the configured vLLM server exposes. 
+ + Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport + we care about is mounted as a POST route there. Falls back to all-False + on any error so the caller can still decide what to do (typically: raise + a clear error rather than silently no-op). + """ + import requests + + caps = VLLMWeightSyncCapabilities() + try: + r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout) + r.raise_for_status() + spec = r.json() + routes = sorted((spec.get("paths") or {}).keys()) + caps.routes = routes + caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes + caps.lora_filesystem = "/v1/load_lora_adapter" in routes + caps.lora_axolotl = "/set_lora_adapter/" in routes + caps.http_full = "/http_update_weights/" in routes + caps.probed = True + except Exception as exc: + caps.probe_error = f"{type(exc).__name__}: {exc}" + LOG.warning( + "NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. " + "Will fall back to LoRA-only behavior.", + base_url, + caps.probe_error, + ) + return caps + + +def select_weight_sync_transport( + caps: VLLMWeightSyncCapabilities, + *, + has_lora: bool, + vllm_lora_sync_pref: bool, +) -> str: + """Pick the right transport for a (server caps, model type) combo. + + Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or + ``"none"``. The caller decides what to do with ``"none"`` (typically: + raise an error explaining the misconfiguration). 
+ + Selection table: + LoRA model + lora endpoint + lora-sync pref → lora_filesystem + LoRA model + lora endpoint → lora_filesystem + LoRA model + nccl endpoint → nccl (broadcast merged adapter) + Full model + nccl endpoint → nccl + Full model + http endpoint → http_full + anything else → none + """ + if has_lora: + if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref: + return "lora_filesystem" + if caps.lora_filesystem or caps.lora_axolotl: + return "lora_filesystem" + if caps.nccl: + return "nccl" + return "none" + # Full-parameter model + if caps.nccl: + return "nccl" + if caps.http_full: + return "http_full" + return "none" + + class NemoGymPlugin(BasePlugin): """Plugin for NVIDIA NeMo Gym integration with Axolotl. @@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin): self._reward_fn = None self._dataset_lookup = None self._agent_servers = {} + self._vllm_caps: VLLMWeightSyncCapabilities | None = None def get_input_args(self): return "axolotl.integrations.nemo_gym.NemoGymArgs" def pre_model_load(self, cfg): - """Apply monkeypatches before trainer creation.""" + """Probe vLLM weight-sync routes and conditionally bypass NCCL init. + + Replaces the previous unconditional ``init_communicator`` monkey-patch + with a probe of the configured vLLM server's ``/openapi.json``. We only + bypass NCCL init when the server we're talking to actually lacks the + ``/init_communicator/`` route (i.e. stock ``vllm serve``); against + TRL/axolotl serve modules that DO expose NCCL routes, we leave the + standard TRL flow alone so full-finetune training can sync weights. + """ if not cfg.nemo_gym_enabled: return - # Always skip NCCL communicator init in NeMo Gym mode. - # NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL - # colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers. 
trl_cfg = getattr(cfg, "trl", None) - if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server": + if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"): + return + + host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1" + port = getattr(trl_cfg, "vllm_server_port", None) or 8000 + base_url = f"http://{host}:{port}" + self._vllm_caps = probe_vllm_weight_sync(base_url) + + if self._vllm_caps.probed: + LOG.info( + "NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s " + "lora_axolotl=%s http_full=%s", + base_url, + self._vllm_caps.nccl, + self._vllm_caps.lora_filesystem, + self._vllm_caps.lora_axolotl, + self._vllm_caps.http_full, + ) + + # Only bypass NCCL init when the server doesn't speak it. If NCCL is + # available we leave VLLMClient.init_communicator alone so the + # standard TRL sync flow can run for full-parameter training. + if not self._vllm_caps.nccl: self._patch_skip_nccl_init() def _patch_skip_nccl_init(self): """Monkeypatch VLLMClient.init_communicator to no-op. - NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA - serve script). The NCCL communicator is not needed and fails with both - vLLM V1 engine and standard OpenAI server mode. + Only called when the configured vLLM server doesn't expose + ``/init_communicator/`` (e.g. stock ``vllm serve``). In that case + TRL's standard ``init_communicator`` would 404 inside trainer + construction; we no-op it so the LoRA filesystem path can install + its own sync in ``post_trainer_create``. 
""" try: from trl.generation.vllm_client import VLLMClient VLLMClient._original_init_communicator = VLLMClient.init_communicator VLLMClient.init_communicator = lambda self, **kwargs: LOG.info( - "Skipping NCCL init_communicator (LoRA sync mode)" + "Skipping NCCL init_communicator (server has no /init_communicator/)" + ) + LOG.info( + "Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)" ) - LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync") except Exception as exc: LOG.warning(f"Failed to patch VLLMClient: {exc}") @@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin): verify_timeout = cfg.nemo_gym_verify_timeout or 30 multi_turn = cfg.nemo_gym_multi_turn or False - # Handle weight sync. NeMo Gym skips NCCL init, so we need to either: - # - Install LoRA sync (when vllm_lora_sync=True) - # - Or no-op sync_weights (when using standard vLLM server) + # Pick a weight-sync transport based on what the configured vLLM + # server actually exposes (see ``pre_model_load`` probe) and what + # kind of model we're training. The selection table is documented + # in ``select_weight_sync_transport``. 
trl_cfg = getattr(cfg, "trl", None) if hasattr(trainer, "vllm_generation") and trainer.vllm_generation: vllm_gen = trainer.vllm_generation - if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False): + adapter = getattr(cfg, "adapter", None) + has_lora = adapter in ("lora", "qlora") + vllm_lora_sync_pref = bool( + trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False) + ) + caps = self._vllm_caps or VLLMWeightSyncCapabilities() + transport = select_weight_sync_transport( + caps, + has_lora=has_lora, + vllm_lora_sync_pref=vllm_lora_sync_pref, + ) + + if transport == "lora_filesystem": self._setup_lora_sync(trainer) - # Verify the vLLM server supports runtime LoRA loading self._check_lora_endpoint(vllm_gen) - else: - # No NCCL, no LoRA sync — skip all weight sync paths - vllm_gen.sync_weights = lambda: LOG.debug( - "Weight sync skipped (NeMo Gym mode)" + LOG.info("NeMo Gym weight sync: LoRA filesystem") + elif transport == "nccl": + # Standard TRL NCCL path. We leave ``VLLMClient.init_communicator`` + # alone (pre_model_load only patched it when the probe found no + # NCCL route) so the trainer's normal weight-sync flow runs. + LOG.info( + "NeMo Gym weight sync: NCCL (server exposes /init_communicator/)" ) - type(vllm_gen).sync_weights = lambda self: LOG.debug( - "Weight sync skipped (NeMo Gym mode)" + elif transport == "http_full": + # Full-parameter HTTP sync — implementation lands in step 3. + # For now, fail loudly so users know the path is detected but + # not yet wired up, instead of silently no-oping like before. + raise NotImplementedError( + "NeMo Gym + full fine-tune + HTTP weight sync is detected " + "but the client-side sync helper is not yet implemented " + "(planned). Use `adapter: lora|qlora` for now, or use a " + "vLLM serve module that exposes /init_communicator/ for " + "NCCL sync." 
) - # Also patch the async trainer's internal sync method - if hasattr(trainer, "_maybe_sync_vllm_weights"): - trainer._maybe_sync_vllm_weights = lambda: LOG.debug( - "Async weight sync skipped (NeMo Gym mode)" + else: # transport == "none" + # No viable sync path. Build a precise error so the user knows + # exactly what's missing and how to fix it. + if not caps.probed: + msg = ( + "could not probe the vLLM server's " + f"/openapi.json: {caps.probe_error}. " + "Verify that vLLM is reachable at " + f"{getattr(trl_cfg, 'vllm_server_host', '?')}:" + f"{getattr(trl_cfg, 'vllm_server_port', '?')}." ) - LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)") + elif has_lora: + msg = ( + "the vLLM server has neither NCCL routes " + "(/init_communicator/) nor a LoRA-loading route " + "(/v1/load_lora_adapter or /set_lora_adapter/). " + "Restart vLLM with `--enable-lora --max-lora-rank N " + "VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock " + "server, or use `axolotl vllm-serve` for the " + "NCCL-capable serve module." + ) + else: + msg = ( + "the vLLM server exposes no full-parameter sync route " + "(/init_communicator/ for NCCL or /http_update_weights/ " + "for HTTP). Use `axolotl vllm-serve` (which has both) " + "or set `adapter: lora|qlora`." + ) + raise ValueError( + f"NeMo Gym: no usable weight-sync transport — {msg} Without " + "weight sync the trainer's gradient updates never reach the " + "rollout policy (functionally a no-op trainer)." 
+ ) if multi_turn: self._wire_multi_turn(cfg, trainer, model_name, verify_timeout) diff --git a/tests/integrations/test_nemo_gym.py b/tests/integrations/test_nemo_gym.py index 7fd53cee0..652f744ca 100644 --- a/tests/integrations/test_nemo_gym.py +++ b/tests/integrations/test_nemo_gym.py @@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase): assert cfg.dataloader_num_workers == 0 +class TestSelectWeightSyncTransport(unittest.TestCase): + """Pure-logic table tests for ``select_weight_sync_transport``.""" + + def _caps(self, **kwargs): + from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities + + c = VLLMWeightSyncCapabilities(probed=True) + for k, v in kwargs.items(): + setattr(c, k, v) + return c + + def test_lora_with_native_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(lora_filesystem=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True) + == "lora_filesystem" + ) + + def test_lora_with_axolotl_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(lora_axolotl=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False) + == "lora_filesystem" + ) + + def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(nccl=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False) + == "nccl" + ) + + def test_full_param_prefers_nccl(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(nccl=True, http_full=True) + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "nccl" + ) + + def test_full_param_falls_back_to_http(self): + from axolotl.integrations.nemo_gym.plugin import 
select_weight_sync_transport + + caps = self._caps(http_full=True) + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "http_full" + ) + + def test_full_param_no_routes_returns_none(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps() # all False + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "none" + ) + + def test_lora_no_routes_returns_none(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps() + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True) + == "none" + ) + + +class TestProbeVllmWeightSync(unittest.TestCase): + """``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps.""" + + def test_stock_vllm_with_lora_enabled(self): + """Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/v1/models": {"get": {}}, + "/v1/load_lora_adapter": {"post": {}}, + "/v1/unload_lora_adapter": {"post": {}}, + "/v1/completions": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.lora_filesystem is True + assert caps.lora_axolotl is False + assert caps.nccl is False + assert caps.http_full is False + + def test_axolotl_serve_lora_full_capabilities(self): + """``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/init_communicator/": {"post": {}}, + "/update_named_param/": {"post": {}}, + 
"/batch_update_named_params/": {"post": {}}, + "/set_lora_adapter/": {"post": {}}, + "/clear_lora_adapter/": {"post": {}}, + "/http_update_weights/": {"post": {}}, + "/v1/load_lora_adapter": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.nccl is True + assert caps.lora_axolotl is True + assert caps.lora_filesystem is True + assert caps.http_full is True + + def test_trl_vllm_serve_nccl_only(self): + """``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/init_communicator/": {"post": {}}, + "/update_named_param/": {"post": {}}, + "/batch_update_named_params/": {"post": {}}, + "/close_communicator/": {"post": {}}, + "/generate/": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.nccl is True + assert caps.lora_filesystem is False + assert caps.lora_axolotl is False + assert caps.http_full is False + + def test_unreachable_server_records_error(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + with patch("requests.get") as mock_get: + mock_get.side_effect = ConnectionError("Connection refused") + caps = probe_vllm_weight_sync("http://localhost:9999") + + assert caps.probed is False + assert caps.probe_error is not None + assert "ConnectionError" in caps.probe_error + assert caps.nccl is False + assert caps.lora_filesystem is False + + +class TestPluginWeightSyncEnforcement(unittest.TestCase): + """End-to-end test of 
post_trainer_create's transport-selection branch. + + The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``, + leaving the trainer learning in isolation while vLLM kept serving the + unmodified base model. After the fix: + + - LoRA + LoRA-loading endpoint → installs filesystem LoRA sync + - LoRA + only NCCL endpoint → uses NCCL broadcast + - Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow) + - Full FT + HTTP endpoint → raises NotImplementedError (step 3) + - No usable transport → raises ValueError with a precise diagnosis + """ + + @staticmethod + def _fake_cfg(adapter, vllm_lora_sync): + class FakeTRL: + pass + + class FakeCfg: + pass + + trl = FakeTRL() + trl.vllm_lora_sync = vllm_lora_sync + trl.vllm_server_host = "127.0.0.1" + trl.vllm_server_port = 8000 + + cfg = FakeCfg() + cfg.nemo_gym_enabled = True + cfg.nemo_gym_model_name = None + cfg.base_model = "test/model" + cfg.nemo_gym_verify_timeout = 30 + cfg.nemo_gym_multi_turn = True + cfg.adapter = adapter + cfg.trl = trl + return cfg + + @staticmethod + def _fake_trainer(): + class FakeVLLMGen: + sync_weights = staticmethod(lambda: None) + + class FakeTrainer: + vllm_generation = FakeVLLMGen() + + return FakeTrainer() + + @staticmethod + def _caps(**kwargs): + from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities + + c = VLLMWeightSyncCapabilities(probed=True) + for k, v in kwargs.items(): + setattr(c, k, v) + return c + + def test_lora_with_lora_endpoint_installs_filesystem_sync(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(lora_filesystem=True) + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True) + trainer = self._fake_trainer() + + with ( + patch.object(plugin, "_setup_lora_sync") as setup, + patch.object(plugin, "_check_lora_endpoint") as check, + patch.object(plugin, "_wire_multi_turn") as wire, + ): + 
plugin.post_trainer_create(cfg, trainer) + setup.assert_called_once() + check.assert_called_once() + wire.assert_called_once() + + def test_lora_with_no_routes_raises_with_lora_specific_message(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps() # all False, but probed + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False) + trainer = self._fake_trainer() + + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + msg = str(ctx.exception) + assert "no-op trainer" in msg + assert "load_lora_adapter" in msg + assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg + + def test_full_finetune_with_nccl_endpoint_uses_nccl(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(nccl=True) + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + + with patch.object(plugin, "_wire_multi_turn") as wire: + plugin.post_trainer_create(cfg, trainer) + wire.assert_called_once() + + def test_full_finetune_with_http_endpoint_not_implemented_yet(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(http_full=True) + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + with self.assertRaises(NotImplementedError) as ctx: + plugin.post_trainer_create(cfg, trainer) + assert "HTTP weight sync" in str(ctx.exception) + + def test_full_finetune_with_no_routes_raises_with_full_param_message(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps() + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + msg = 
str(ctx.exception) + assert "no-op trainer" in msg + assert "init_communicator" in msg + assert "http_update_weights" in msg + + def test_unprobed_caps_raises_with_probe_failure_message(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + # Plugin._vllm_caps left as default-None: the post_trainer_create + # branch falls back to a fresh VLLMWeightSyncCapabilities() with + # probed=False, so the error path should mention probing. + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True) + trainer = self._fake_trainer() + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + assert "could not probe" in str(ctx.exception) + + class TestNemoGymE2E(unittest.TestCase): """End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.