make _maybe_sync_vllm_weights actually fire in sync mode
Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:
1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
at the top. The original design assumed sync mode would fall back
to TRL's stock per-step ``sync_weights`` call inside
``_generate_single_turn`` — true for vanilla GRPO but FALSE in
NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
server directly and ``_generate_single_turn`` is never invoked.
Result: no sync ever happened in NeMo Gym sync mode.
2. ``step % vllm_sync_interval`` would TypeError on the first call if
``vllm_sync_interval`` was unset (the default for any config that
doesn't explicitly set it).
3. The ``_generate_single_turn`` patch installed
``vllm_generation.sync_weights = lambda: None`` unconditionally
for vllm_lora_sync runs. That's correct in async-prefetch mode
(BG thread can't safely sync) but wrong in sync mode: TRL's
per-step auto-sync inside ``_generate_single_turn`` was the
fallback that the early return in (1) was assuming, and the
no-op patch was killing it.
Fix:
- Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
is now the canonical sync trigger and runs in both modes from
``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
- Default ``vllm_sync_interval`` to 1 when unset.
- In the ``_generate_single_turn`` patch, route sync_weights to
``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
in async mode for the BG-thread safety reason).
This commit is contained in:
@@ -1103,11 +1103,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
||||
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
||||
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
||||
|
||||
This is the canonical sync trigger and runs in BOTH async and
|
||||
synchronous modes from ``_prepare_inputs_with_data_producer`` /
|
||||
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
|
||||
patch is a parallel backup for non-data-producer paths (vanilla
|
||||
GRPO without NeMo Gym), where the data producer is bypassed
|
||||
entirely and TRL's stock generate-then-sync flow is used instead.
|
||||
"""
|
||||
if not (self.use_vllm and self.args.async_prefetch):
|
||||
if not self.use_vllm:
|
||||
return
|
||||
step = self.state.global_step
|
||||
interval = self.args.vllm_sync_interval
|
||||
# Default to syncing every step when no interval is configured —
|
||||
# otherwise ``step % None`` would TypeError, and the previous
|
||||
# behavior of crashing on the first sync was strictly worse than
|
||||
# the standard "sync every optimizer step".
|
||||
interval = self.args.vllm_sync_interval or 1
|
||||
if step != self._last_synced_step and step % interval == 0:
|
||||
if step == 0:
|
||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||
@@ -1202,13 +1213,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
# Permanently replace vllm_generation.sync_weights with our custom
|
||||
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
||||
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
|
||||
# handles the sync with proper interval tracking.
|
||||
#
|
||||
# The design has two modes that have to be threaded carefully:
|
||||
#
|
||||
# - Async prefetch ON: BG generation thread can't safely call
|
||||
# sync_weights mid-rollout (it races with the trainer's optimizer
|
||||
# step and can corrupt weights). We no-op the stock sync hook and
|
||||
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
|
||||
# optimizer step on the main thread.
|
||||
#
|
||||
# - Async prefetch OFF (synchronous mode): TRL's stock
|
||||
# ``_generate_single_turn`` calls ``sync_weights`` once per step
|
||||
# boundary. There's no BG thread to race with, and
|
||||
# ``_maybe_sync_vllm_weights`` short-circuits with
|
||||
# ``if not async_prefetch: return``, so we MUST wire the stock
|
||||
# hook directly to our LoRA sync helper — otherwise nothing ever
|
||||
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
|
||||
# keeps serving the base model, every rollout in every group
|
||||
# produces identical outputs, advantages are zero, optimizer
|
||||
# step gets skipped, repeat).
|
||||
if not getattr(self, "_patched_sync_weights", False):
|
||||
if self.use_vllm and hasattr(self, "vllm_generation"):
|
||||
if getattr(self.args, "vllm_lora_sync", False):
|
||||
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
if getattr(self.args, "async_prefetch", False):
|
||||
# Async: drive sync from main thread via
|
||||
# _maybe_sync_vllm_weights instead.
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
# Sync mode: TRL's _generate_single_turn already
|
||||
# calls sync_weights once per step boundary. Wire
|
||||
# it directly to our LoRA filesystem sync helper.
|
||||
sync_helper = self._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
self.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
self._patched_sync_weights = True
|
||||
else:
|
||||
from accelerate.utils import is_peft_model
|
||||
|
||||
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
class TestVllmLoraSyncPatch(unittest.TestCase):
|
||||
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
|
||||
|
||||
These tests exercise the patch-installation branch in isolation. They build
|
||||
a stub trainer with just enough attributes to look like
|
||||
``AsyncGRPOTrainer`` for the duration of the relevant code path.
|
||||
|
||||
Background — there are two correct behaviors and we historically had a bug
|
||||
where both modes used the same one:
|
||||
|
||||
- Async prefetch ON: the BG generation thread can't safely call
|
||||
sync_weights mid-rollout. We no-op the stock hook and drive sync from
|
||||
the main thread via ``_maybe_sync_vllm_weights``.
|
||||
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
|
||||
calls ``sync_weights`` once per step boundary on the main thread. We
|
||||
wire that hook directly to ``_sync_lora_adapter`` because
|
||||
``_maybe_sync_vllm_weights`` short-circuits when async is off.
|
||||
|
||||
Before the fix, both modes installed ``lambda: None``, so sync mode never
|
||||
pushed any LoRA adapter to vLLM and the trainer was a no-op.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.vllm_lora_sync = vllm_lora_sync
|
||||
args.async_prefetch = async_prefetch
|
||||
|
||||
class FakeVllmGen:
|
||||
sync_weights = staticmethod(lambda: None)
|
||||
model = MagicMock()
|
||||
|
||||
# Use object.__new__ so we don't run __init__ (which needs a real
|
||||
# model, dataset, etc.). We only need the `_generate_single_turn`
|
||||
# method's patch branch to run, so we set up the minimum state.
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.vllm_generation = FakeVllmGen()
|
||||
trainer._patched_sync_weights = False
|
||||
# Spy on _sync_lora_adapter so we can assert it's the function the
|
||||
# hook delegates to in sync mode.
|
||||
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
|
||||
trainer._sync_peft_weights_no_merge = MagicMock(
|
||||
name="_sync_peft_weights_no_merge_spy"
|
||||
)
|
||||
return trainer
|
||||
|
||||
@staticmethod
|
||||
def _run_patch_branch(trainer):
|
||||
"""Execute just the sync_weights-patching branch in isolation.
|
||||
|
||||
We can't easily call the real ``_generate_single_turn`` because it
|
||||
does a full vLLM generate. Instead we copy the exact branch out of
|
||||
the source so the test verifies the same logic the trainer runs.
|
||||
"""
|
||||
if not getattr(trainer, "_patched_sync_weights", False):
|
||||
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
|
||||
if getattr(trainer.args, "vllm_lora_sync", False):
|
||||
if getattr(trainer.args, "async_prefetch", False):
|
||||
trainer.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
sync_helper = trainer._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
trainer._patched_sync_weights = True
|
||||
|
||||
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Trigger the patched hook — it must call _sync_lora_adapter.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_lora_sync_installs_noop_hook(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Hook must be a no-op so BG-thread generation doesn't fight the
|
||||
# main-thread optimizer step over the model weights.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
|
||||
"""Installing the patch should not pre-emptively sync."""
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
# _sync_lora_adapter should only be called when the patched hook
|
||||
# itself is invoked (e.g., from TRL's _generate_single_turn).
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_patch_is_idempotent(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
first_hook = trainer.vllm_generation.sync_weights
|
||||
# Second call must not re-patch (otherwise we'd lose the original).
|
||||
self._run_patch_branch(trainer)
|
||||
assert trainer.vllm_generation.sync_weights is first_hook
|
||||
|
||||
|
||||
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
|
||||
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
|
||||
|
||||
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
|
||||
on the very first call when ``vllm_sync_interval`` was ``None`` (which
|
||||
is the default for any config that doesn't explicitly set it). We now
|
||||
fall back to interval=1 so unset means "sync every step", matching the
|
||||
behavior of TRL's own ``_generate_single_turn``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(interval, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.async_prefetch = async_prefetch
|
||||
args.vllm_sync_interval = interval
|
||||
args.vllm_lora_sync = True
|
||||
|
||||
class FakeState:
|
||||
global_step = 1
|
||||
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.state = FakeState()
|
||||
trainer._last_synced_step = 0
|
||||
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
|
||||
return trainer
|
||||
|
||||
def test_interval_none_in_async_mode_does_not_crash(self):
|
||||
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# Should not raise TypeError — defaults to every-step sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_sync_mode_drives_sync(self):
|
||||
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
|
||||
|
||||
The previous behavior (early return when ``not async_prefetch``)
|
||||
assumed TRL's stock ``_generate_single_turn`` would handle sync.
|
||||
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
|
||||
where the data producer bypasses ``_generate_single_turn``
|
||||
entirely. Without this trigger no sync ever happens and the
|
||||
trainer becomes a no-op.
|
||||
"""
|
||||
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_explicit_interval_respects_modulo(self):
|
||||
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
# global_step=4 → 4 % 4 == 0 → sync
|
||||
trainer.state.global_step = 4
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user