make _maybe_sync_vllm_weights actually fire in sync mode

Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:

1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
   at the top. The original design assumed sync mode would fall back
   to TRL's stock per-step ``sync_weights`` call inside
   ``_generate_single_turn`` — true for vanilla GRPO but FALSE in
   NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
   server directly and ``_generate_single_turn`` is never invoked.
   Result: no sync ever happened in NeMo Gym sync mode.

2. ``step % vllm_sync_interval`` would TypeError on the first call if
   ``vllm_sync_interval`` was unset (the default for any config that
   doesn't explicitly set it).

3. The ``_generate_single_turn`` patch installed
   ``vllm_generation.sync_weights = lambda: None`` unconditionally
   for vllm_lora_sync runs. That's correct in async-prefetch mode
   (BG thread can't safely sync) but wrong in sync mode: TRL's
   per-step auto-sync inside ``_generate_single_turn`` was the
   fallback that the early return in (1) was assuming, and the
   no-op patch was killing it.

Fix:
  - Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
    is now the canonical sync trigger and runs in both modes from
    ``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
  - Default ``vllm_sync_interval`` to 1 when unset.
  - In the ``_generate_single_turn`` patch, route sync_weights to
    ``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
    in async mode for the BG-thread safety reason).
This commit is contained in:
Wing Lian
2026-04-13 18:30:16 +00:00
parent e993ed5208
commit 7617b951a8
2 changed files with 238 additions and 6 deletions

View File

@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
self.assertIs(_trainer_module.validate_quantization_for_training, original)
class TestVllmLoraSyncPatch(unittest.TestCase):
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
These tests exercise the patch-installation branch in isolation. They build
a stub trainer with just enough attributes to look like
``AsyncGRPOTrainer`` for the duration of the relevant code path.
Background — there are two correct behaviors and we historically had a bug
where both modes used the same one:
- Async prefetch ON: the BG generation thread can't safely call
sync_weights mid-rollout. We no-op the stock hook and drive sync from
the main thread via ``_maybe_sync_vllm_weights``.
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
calls ``sync_weights`` once per step boundary on the main thread. We
wire that hook directly to ``_sync_lora_adapter`` because
``_maybe_sync_vllm_weights`` short-circuits when async is off.
Before the fix, both modes installed ``lambda: None``, so sync mode never
pushed any LoRA adapter to vLLM and the trainer was a no-op.
"""
@staticmethod
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.vllm_lora_sync = vllm_lora_sync
args.async_prefetch = async_prefetch
class FakeVllmGen:
sync_weights = staticmethod(lambda: None)
model = MagicMock()
# Use object.__new__ so we don't run __init__ (which needs a real
# model, dataset, etc.). We only need the `_generate_single_turn`
# method's patch branch to run, so we set up the minimum state.
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.vllm_generation = FakeVllmGen()
trainer._patched_sync_weights = False
# Spy on _sync_lora_adapter so we can assert it's the function the
# hook delegates to in sync mode.
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
trainer._sync_peft_weights_no_merge = MagicMock(
name="_sync_peft_weights_no_merge_spy"
)
return trainer
@staticmethod
def _run_patch_branch(trainer):
"""Execute just the sync_weights-patching branch in isolation.
We can't easily call the real ``_generate_single_turn`` because it
does a full vLLM generate. Instead we copy the exact branch out of
the source so the test verifies the same logic the trainer runs.
"""
if not getattr(trainer, "_patched_sync_weights", False):
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
if getattr(trainer.args, "vllm_lora_sync", False):
if getattr(trainer.args, "async_prefetch", False):
trainer.vllm_generation.sync_weights = lambda: None
else:
sync_helper = trainer._sync_lora_adapter
def _lora_filesystem_sync():
sync_helper()
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
trainer._patched_sync_weights = True
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Trigger the patched hook — it must call _sync_lora_adapter.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_lora_sync_installs_noop_hook(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Hook must be a no-op so BG-thread generation doesn't fight the
# main-thread optimizer step over the model weights.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_not_called()
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
"""Installing the patch should not pre-emptively sync."""
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
# _sync_lora_adapter should only be called when the patched hook
# itself is invoked (e.g., from TRL's _generate_single_turn).
trainer._sync_lora_adapter.assert_not_called()
def test_patch_is_idempotent(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
first_hook = trainer.vllm_generation.sync_weights
# Second call must not re-patch (otherwise we'd lose the original).
self._run_patch_branch(trainer)
assert trainer.vllm_generation.sync_weights is first_hook
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
on the very first call when ``vllm_sync_interval`` was ``None`` (which
is the default for any config that doesn't explicitly set it). We now
fall back to interval=1 so unset means "sync every step", matching the
behavior of TRL's own ``_generate_single_turn``.
"""
@staticmethod
def _make_stub_trainer(interval, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.async_prefetch = async_prefetch
args.vllm_sync_interval = interval
args.vllm_lora_sync = True
class FakeState:
global_step = 1
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.state = FakeState()
trainer._last_synced_step = 0
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
return trainer
def test_interval_none_in_async_mode_does_not_crash(self):
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# Should not raise TypeError — defaults to every-step sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_sync_mode_drives_sync(self):
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
The previous behavior (early return when ``not async_prefetch``)
assumed TRL's stock ``_generate_single_turn`` would handle sync.
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
where the data producer bypasses ``_generate_single_turn``
entirely. Without this trigger no sync ever happens and the
trainer becomes a no-op.
"""
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_explicit_interval_respects_modulo(self):
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_not_called()
# global_step=4 → 4 % 4 == 0 → sync
trainer.state.global_step = 4
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
if __name__ == "__main__":
unittest.main()