probe vLLM weight-sync routes and select transport per server
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).
Replace both patches with:
1. A probe of the configured vLLM server's /openapi.json at
pre_model_load. Three transports are recognized:
- NCCL (/init_communicator/ + /update_named_param/) — TRL serve
and axolotl vllm-serve both expose this
- LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
- HTTP base64 full-weight (/http_update_weights/) — axolotl
vllm-serve only
2. A pure-logic ``select_weight_sync_transport`` that picks the
right one for (server caps × adapter type).
3. ``init_communicator`` is only patched out when the server has no
NCCL routes; against TRL/axolotl serve modules it stays live so
full-finetune NCCL sync works.
4. ``post_trainer_create`` uses the selection table to install LoRA
filesystem sync OR leave the standard NCCL flow alone OR raise
NotImplementedError (HTTP — pending) OR raise a precise diagnosis
when no transport is viable. No more silent no-op trainers.
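The selection logic is pure, so it can be sanity-checked in isolation. An
illustrative doctest against the names this commit introduces (module path
as used by the new tests):

    >>> from axolotl.integrations.nemo_gym.plugin import (
    ...     VLLMWeightSyncCapabilities, select_weight_sync_transport)
    >>> caps = VLLMWeightSyncCapabilities(probed=True, lora_filesystem=True)
    >>> select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
    'lora_filesystem'
    >>> caps = VLLMWeightSyncCapabilities(probed=True, nccl=True)
    >>> select_weight_sync_transport(caps, has_lora=False, vllm_lora_sync_pref=False)
    'nccl'
    >>> select_weight_sync_transport(VLLMWeightSyncCapabilities(),
    ...                              has_lora=False, vllm_lora_sync_pref=False)
    'none'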
@@ -19,6 +19,7 @@ Supports two modes:

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Union

from axolotl.integrations.base import BasePlugin
@@ -30,6 +31,107 @@ if TYPE_CHECKING:

LOG = get_logger(__name__)


# ---- vLLM weight-sync transport probe ------------------------------------


@dataclass
class VLLMWeightSyncCapabilities:
    """What weight-sync routes a vLLM server actually exposes.

    Discovered once at ``pre_model_load`` time by fetching the server's
    ``/openapi.json``. Drives the transport-selection table below.
    """

    nccl: bool = False  # /init_communicator/ + /update_named_param/
    lora_filesystem: bool = False  # /v1/load_lora_adapter (vLLM native)
    lora_axolotl: bool = False  # /set_lora_adapter/ (axolotl serve_lora extension)
    http_full: bool = False  # /http_update_weights/ (axolotl serve_lora extension)
    probed: bool = False
    probe_error: str | None = None
    routes: list[str] = field(default_factory=list)

    @property
    def any_full_param_sync(self) -> bool:
        """True if at least one transport can push full-model weights."""
        return self.nccl or self.http_full

    @property
    def any_lora_sync(self) -> bool:
        """True if at least one transport can push LoRA adapters."""
        return self.lora_filesystem or self.lora_axolotl or self.nccl
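
# Known server profiles, per the probe tests below (indicative, not
# exhaustive — the probe always trusts the live /openapi.json over any table):
#   trl vllm-serve           → nccl
#   vllm serve --enable-lora → lora_filesystem
#   axolotl vllm-serve       → nccl + lora_axolotl + lora_filesystem + http_full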


def probe_vllm_weight_sync(
    base_url: str, timeout: float = 5.0
) -> VLLMWeightSyncCapabilities:
    """Detect which weight-sync routes the configured vLLM server exposes.

    Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
    we care about is mounted as a POST route there. Falls back to all-False
    on any error so the caller can still decide what to do (typically: raise
    a clear error rather than silently no-op).
    """
    import requests

    caps = VLLMWeightSyncCapabilities()
    try:
        r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
        r.raise_for_status()
        spec = r.json()
        routes = sorted((spec.get("paths") or {}).keys())
        caps.routes = routes
        caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
        caps.lora_filesystem = "/v1/load_lora_adapter" in routes
        caps.lora_axolotl = "/set_lora_adapter/" in routes
        caps.http_full = "/http_update_weights/" in routes
        caps.probed = True
    except Exception as exc:
        caps.probe_error = f"{type(exc).__name__}: {exc}"
        LOG.warning(
            "NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. "
            "Will fall back to LoRA-only behavior.",
            base_url,
            caps.probe_error,
        )
    return caps
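
# Example (sketch): against a ``trl vllm-serve`` instance the probe would
# report roughly:
#     caps = probe_vllm_weight_sync("http://127.0.0.1:8000")
#     caps.nccl            -> True   (/init_communicator/ + /update_named_param/)
#     caps.lora_filesystem -> False  (no /v1/load_lora_adapter route)
#     caps.any_lora_sync   -> True   (NCCL can broadcast a merged adapter)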


def select_weight_sync_transport(
    caps: VLLMWeightSyncCapabilities,
    *,
    has_lora: bool,
    vllm_lora_sync_pref: bool,
) -> str:
    """Pick the right transport for a (server caps, model type) combo.

    Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
    ``"none"``. The caller decides what to do with ``"none"`` (typically:
    raise an error explaining the misconfiguration).

    Selection table:
        LoRA model + lora endpoint + lora-sync pref → lora_filesystem
        LoRA model + lora endpoint                  → lora_filesystem
        LoRA model + nccl endpoint                  → nccl (broadcast merged adapter)
        Full model + nccl endpoint                  → nccl
        Full model + http endpoint                  → http_full
        anything else                               → none
    """
    if has_lora:
        if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
            return "lora_filesystem"
        if caps.lora_filesystem or caps.lora_axolotl:
            return "lora_filesystem"
        if caps.nccl:
            return "nccl"
        return "none"
    # Full-parameter model
    if caps.nccl:
        return "nccl"
    if caps.http_full:
        return "http_full"
    return "none"


class NemoGymPlugin(BasePlugin):
    """Plugin for NVIDIA NeMo Gym integration with Axolotl.

@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
        self._reward_fn = None
        self._dataset_lookup = None
        self._agent_servers = {}
        self._vllm_caps: VLLMWeightSyncCapabilities | None = None

    def get_input_args(self):
        return "axolotl.integrations.nemo_gym.NemoGymArgs"

    def pre_model_load(self, cfg):
        """Probe vLLM weight-sync routes and conditionally bypass NCCL init.

        Replaces the previous unconditional ``init_communicator`` monkey-patch
        with a probe of the configured vLLM server's ``/openapi.json``. We only
        bypass NCCL init when the server we're talking to actually lacks the
        ``/init_communicator/`` route (i.e. stock ``vllm serve``); against
        TRL/axolotl serve modules that DO expose NCCL routes, we leave the
        standard TRL flow alone so full-finetune training can sync weights.
        """
        if not cfg.nemo_gym_enabled:
            return

        trl_cfg = getattr(cfg, "trl", None)
        if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
            return

        host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
        port = getattr(trl_cfg, "vllm_server_port", None) or 8000
        base_url = f"http://{host}:{port}"
        self._vllm_caps = probe_vllm_weight_sync(base_url)

        if self._vllm_caps.probed:
            LOG.info(
                "NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
                "lora_axolotl=%s http_full=%s",
                base_url,
                self._vllm_caps.nccl,
                self._vllm_caps.lora_filesystem,
                self._vllm_caps.lora_axolotl,
                self._vllm_caps.http_full,
            )

        # Only bypass NCCL init when the server doesn't speak it. If NCCL is
        # available we leave VLLMClient.init_communicator alone so the
        # standard TRL sync flow can run for full-parameter training.
        if not self._vllm_caps.nccl:
            self._patch_skip_nccl_init()

    def _patch_skip_nccl_init(self):
        """Monkeypatch VLLMClient.init_communicator to no-op.

        Only called when the configured vLLM server doesn't expose
        ``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
        TRL's standard ``init_communicator`` would 404 inside trainer
        construction; we no-op it so the LoRA filesystem path can install
        its own sync in ``post_trainer_create``.
        """
        try:
            from trl.generation.vllm_client import VLLMClient

            VLLMClient._original_init_communicator = VLLMClient.init_communicator
            VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
                "Skipping NCCL init_communicator (server has no /init_communicator/)"
            )
            LOG.info(
                "Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
            )
        except Exception as exc:
            LOG.warning(f"Failed to patch VLLMClient: {exc}")

@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
        verify_timeout = cfg.nemo_gym_verify_timeout or 30
        multi_turn = cfg.nemo_gym_multi_turn or False

        # Pick a weight-sync transport based on what the configured vLLM
        # server actually exposes (see ``pre_model_load`` probe) and what
        # kind of model we're training. The selection table is documented
        # in ``select_weight_sync_transport``.
        trl_cfg = getattr(cfg, "trl", None)
        if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
            vllm_gen = trainer.vllm_generation
            adapter = getattr(cfg, "adapter", None)
            has_lora = adapter in ("lora", "qlora")
            vllm_lora_sync_pref = bool(
                trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
            )
            caps = self._vllm_caps or VLLMWeightSyncCapabilities()
            transport = select_weight_sync_transport(
                caps,
                has_lora=has_lora,
                vllm_lora_sync_pref=vllm_lora_sync_pref,
            )

            if transport == "lora_filesystem":
                self._setup_lora_sync(trainer)
                # Verify the vLLM server supports runtime LoRA loading
                self._check_lora_endpoint(vllm_gen)
                LOG.info("NeMo Gym weight sync: LoRA filesystem")
            elif transport == "nccl":
                # Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
                # alone (pre_model_load only patched it when the probe found no
                # NCCL route) so the trainer's normal weight-sync flow runs.
                LOG.info(
                    "NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
                )
            elif transport == "http_full":
                # Full-parameter HTTP sync — implementation lands in step 3.
                # For now, fail loudly so users know the path is detected but
                # not yet wired up, instead of silently no-oping like before.
                raise NotImplementedError(
                    "NeMo Gym + full fine-tune + HTTP weight sync is detected "
                    "but the client-side sync helper is not yet implemented "
                    "(planned). Use `adapter: lora|qlora` for now, or use a "
                    "vLLM serve module that exposes /init_communicator/ for "
                    "NCCL sync."
                )
            else:  # transport == "none"
                # No viable sync path. Build a precise error so the user knows
                # exactly what's missing and how to fix it.
                if not caps.probed:
                    msg = (
                        "could not probe the vLLM server's "
                        f"/openapi.json: {caps.probe_error}. "
                        "Verify that vLLM is reachable at "
                        f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
                        f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
                    )
                elif has_lora:
                    msg = (
                        "the vLLM server has neither NCCL routes "
                        "(/init_communicator/) nor a LoRA-loading route "
                        "(/v1/load_lora_adapter or /set_lora_adapter/). "
                        "Restart vLLM with `--enable-lora --max-lora-rank N "
                        "VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
                        "server, or use `axolotl vllm-serve` for the "
                        "NCCL-capable serve module."
                    )
                else:
                    msg = (
                        "the vLLM server exposes no full-parameter sync route "
                        "(/init_communicator/ for NCCL or /http_update_weights/ "
                        "for HTTP). Use `axolotl vllm-serve` (which has both) "
                        "or set `adapter: lora|qlora`."
                    )
                raise ValueError(
                    f"NeMo Gym: no usable weight-sync transport — {msg} Without "
                    "weight sync the trainer's gradient updates never reach the "
                    "rollout policy (functionally a no-op trainer)."
                )

        if multi_turn:
            self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)

@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
        assert cfg.dataloader_num_workers == 0


class TestSelectWeightSyncTransport(unittest.TestCase):
    """Pure-logic table tests for ``select_weight_sync_transport``."""

    def _caps(self, **kwargs):
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        c = VLLMWeightSyncCapabilities(probed=True)
        for k, v in kwargs.items():
            setattr(c, k, v)
        return c

    def test_lora_with_native_endpoint(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps(lora_filesystem=True)
        assert (
            select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
            == "lora_filesystem"
        )

    def test_lora_with_axolotl_endpoint(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps(lora_axolotl=True)
        assert (
            select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
            == "lora_filesystem"
        )

    def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps(nccl=True)
        assert (
            select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
            == "nccl"
        )

    def test_full_param_prefers_nccl(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps(nccl=True, http_full=True)
        assert (
            select_weight_sync_transport(
                caps, has_lora=False, vllm_lora_sync_pref=False
            )
            == "nccl"
        )

    def test_full_param_falls_back_to_http(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps(http_full=True)
        assert (
            select_weight_sync_transport(
                caps, has_lora=False, vllm_lora_sync_pref=False
            )
            == "http_full"
        )

    def test_full_param_no_routes_returns_none(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps()  # all False
        assert (
            select_weight_sync_transport(
                caps, has_lora=False, vllm_lora_sync_pref=False
            )
            == "none"
        )

    def test_lora_no_routes_returns_none(self):
        from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport

        caps = self._caps()
        assert (
            select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
            == "none"
        )


class TestProbeVllmWeightSync(unittest.TestCase):
    """``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""

    def test_stock_vllm_with_lora_enabled(self):
        """Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        spec = {
            "paths": {
                "/v1/models": {"get": {}},
                "/v1/load_lora_adapter": {"post": {}},
                "/v1/unload_lora_adapter": {"post": {}},
                "/v1/completions": {"post": {}},
            }
        }
        with patch("requests.get") as mock_get:
            mock_get.return_value.raise_for_status = lambda: None
            mock_get.return_value.json = lambda: spec
            caps = probe_vllm_weight_sync("http://localhost:8000")

        assert caps.probed is True
        assert caps.lora_filesystem is True
        assert caps.lora_axolotl is False
        assert caps.nccl is False
        assert caps.http_full is False

    def test_axolotl_serve_lora_full_capabilities(self):
        """``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        spec = {
            "paths": {
                "/init_communicator/": {"post": {}},
                "/update_named_param/": {"post": {}},
                "/batch_update_named_params/": {"post": {}},
                "/set_lora_adapter/": {"post": {}},
                "/clear_lora_adapter/": {"post": {}},
                "/http_update_weights/": {"post": {}},
                "/v1/load_lora_adapter": {"post": {}},
            }
        }
        with patch("requests.get") as mock_get:
            mock_get.return_value.raise_for_status = lambda: None
            mock_get.return_value.json = lambda: spec
            caps = probe_vllm_weight_sync("http://localhost:8000")

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_axolotl is True
        assert caps.lora_filesystem is True
        assert caps.http_full is True

    def test_trl_vllm_serve_nccl_only(self):
        """``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        spec = {
            "paths": {
                "/init_communicator/": {"post": {}},
                "/update_named_param/": {"post": {}},
                "/batch_update_named_params/": {"post": {}},
                "/close_communicator/": {"post": {}},
                "/generate/": {"post": {}},
            }
        }
        with patch("requests.get") as mock_get:
            mock_get.return_value.raise_for_status = lambda: None
            mock_get.return_value.json = lambda: spec
            caps = probe_vllm_weight_sync("http://localhost:8000")

        assert caps.probed is True
        assert caps.nccl is True
        assert caps.lora_filesystem is False
        assert caps.lora_axolotl is False
        assert caps.http_full is False

    def test_unreachable_server_records_error(self):
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync

        with patch("requests.get") as mock_get:
            mock_get.side_effect = ConnectionError("Connection refused")
            caps = probe_vllm_weight_sync("http://localhost:9999")

        assert caps.probed is False
        assert caps.probe_error is not None
        assert "ConnectionError" in caps.probe_error
        assert caps.nccl is False
        assert caps.lora_filesystem is False


class TestPluginWeightSyncEnforcement(unittest.TestCase):
    """End-to-end test of post_trainer_create's transport-selection branch.

    The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
    leaving the trainer learning in isolation while vLLM kept serving the
    unmodified base model. After the fix:

    - LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
    - LoRA + only NCCL endpoint → uses NCCL broadcast
    - Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
    - Full FT + HTTP endpoint → raises NotImplementedError (step 3)
    - No usable transport → raises ValueError with a precise diagnosis
    """

    @staticmethod
    def _fake_cfg(adapter, vllm_lora_sync):
        class FakeTRL:
            pass

        class FakeCfg:
            pass

        trl = FakeTRL()
        trl.vllm_lora_sync = vllm_lora_sync
        trl.vllm_server_host = "127.0.0.1"
        trl.vllm_server_port = 8000

        cfg = FakeCfg()
        cfg.nemo_gym_enabled = True
        cfg.nemo_gym_model_name = None
        cfg.base_model = "test/model"
        cfg.nemo_gym_verify_timeout = 30
        cfg.nemo_gym_multi_turn = True
        cfg.adapter = adapter
        cfg.trl = trl
        return cfg

    @staticmethod
    def _fake_trainer():
        class FakeVLLMGen:
            sync_weights = staticmethod(lambda: None)

        class FakeTrainer:
            vllm_generation = FakeVLLMGen()

        return FakeTrainer()

    @staticmethod
    def _caps(**kwargs):
        from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities

        c = VLLMWeightSyncCapabilities(probed=True)
        for k, v in kwargs.items():
            setattr(c, k, v)
        return c

    def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(lora_filesystem=True)
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
        trainer = self._fake_trainer()

        with (
            patch.object(plugin, "_setup_lora_sync") as setup,
            patch.object(plugin, "_check_lora_endpoint") as check,
            patch.object(plugin, "_wire_multi_turn") as wire,
        ):
            plugin.post_trainer_create(cfg, trainer)
        setup.assert_called_once()
        check.assert_called_once()
        wire.assert_called_once()

    def test_lora_with_no_routes_raises_with_lora_specific_message(self):
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps()  # all False, but probed
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)
        msg = str(ctx.exception)
        assert "no-op trainer" in msg
        assert "load_lora_adapter" in msg
        assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg

    def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
        from unittest.mock import patch

        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(nccl=True)
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()

        with patch.object(plugin, "_wire_multi_turn") as wire:
            plugin.post_trainer_create(cfg, trainer)
        wire.assert_called_once()

    def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps(http_full=True)
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()
        with self.assertRaises(NotImplementedError) as ctx:
            plugin.post_trainer_create(cfg, trainer)
        assert "HTTP weight sync" in str(ctx.exception)

    def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        plugin._vllm_caps = self._caps()
        cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
        trainer = self._fake_trainer()
        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)
        msg = str(ctx.exception)
        assert "no-op trainer" in msg
        assert "init_communicator" in msg
        assert "http_update_weights" in msg

    def test_unprobed_caps_raises_with_probe_failure_message(self):
        from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin

        plugin = NemoGymPlugin()
        # Plugin._vllm_caps left as default-None: the post_trainer_create
        # branch falls back to a fresh VLLMWeightSyncCapabilities() with
        # probed=False, so the error path should mention probing.
        cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
        trainer = self._fake_trainer()
        with self.assertRaises(ValueError) as ctx:
            plugin.post_trainer_create(cfg, trainer)
        assert "could not probe" in str(ctx.exception)


class TestNemoGymE2E(unittest.TestCase):
    """End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.