diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py
index 3388687ad..47c3a07ad 100644
--- a/src/axolotl/core/trainers/grpo/async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/async_trainer.py
@@ -1103,11 +1103,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
         - vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
         - PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
         - Non-PEFT: stock sync_weights via merge_adapter + NCCL
+
+        This is the canonical sync trigger and runs in BOTH async and
+        synchronous modes from ``_prepare_inputs_with_data_producer`` /
+        ``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
+        patch is a parallel backup for non-data-producer paths (vanilla
+        GRPO without NeMo Gym), where the data producer is bypassed
+        entirely and TRL's stock generate-then-sync flow is used instead.
         """
-        if not (self.use_vllm and self.args.async_prefetch):
+        if not self.use_vllm:
             return
         step = self.state.global_step
-        interval = self.args.vllm_sync_interval
+        # Default to syncing every step when no interval is configured;
+        # otherwise ``step % None`` would raise a TypeError, and the
+        # previous behavior of crashing on the first sync was strictly
+        # worse than the standard "sync every optimizer step".
+        interval = self.args.vllm_sync_interval or 1
         if step != self._last_synced_step and step % interval == 0:
             if step == 0:
                 logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
@@ -1202,13 +1213,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
         # Permanently replace vllm_generation.sync_weights with our custom
         # sync to avoid merge_adapter (fails on FP8 / races with training).
-        # For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
-        # handles the sync with proper interval tracking.
+        #
+        # The design has two modes that have to be threaded carefully:
+        #
+        # - Async prefetch ON: BG generation thread can't safely call
+        #   sync_weights mid-rollout (it races with the trainer's optimizer
+        #   step and can corrupt weights). We no-op the stock sync hook and
+        #   drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
+        #   optimizer step on the main thread.
+        #
+        # - Async prefetch OFF (synchronous mode): TRL's stock
+        #   ``_generate_single_turn`` calls ``sync_weights`` once per step
+        #   boundary. There's no BG thread to race with, but
+        #   ``_maybe_sync_vllm_weights`` is only reached from the
+        #   data-producer prepare paths, so on the stock path we MUST wire
+        #   the hook directly to our LoRA sync helper; otherwise nothing
+        #   ever pushes weights to vLLM and the trainer becomes a no-op
+        #   (vLLM keeps serving the base model, every rollout in every
+        #   group produces identical outputs, advantages are zero, the
+        #   optimizer step gets skipped, repeat).
         if not getattr(self, "_patched_sync_weights", False):
             if self.use_vllm and hasattr(self, "vllm_generation"):
                 if getattr(self.args, "vllm_lora_sync", False):
-                    # No-op: LoRA sync is driven by _maybe_sync_vllm_weights
-                    self.vllm_generation.sync_weights = lambda: None
+                    if getattr(self.args, "async_prefetch", False):
+                        # Async: drive sync from main thread via
+                        # _maybe_sync_vllm_weights instead.
+                        self.vllm_generation.sync_weights = lambda: None
+                    else:
+                        # Sync mode: TRL's _generate_single_turn already
+                        # calls sync_weights once per step boundary. Wire
+                        # it directly to our LoRA filesystem sync helper.
+                        sync_helper = self._sync_lora_adapter
+
+                        def _lora_filesystem_sync():
+                            sync_helper()
+
+                        self.vllm_generation.sync_weights = _lora_filesystem_sync
                     self._patched_sync_weights = True
                 else:
                     from accelerate.utils import is_peft_model
diff --git a/tests/core/test_async_grpo.py b/tests/core/test_async_grpo.py
index 14c38df29..3a4c188bc 100644
--- a/tests/core/test_async_grpo.py
+++ b/tests/core/test_async_grpo.py
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
         self.assertIs(_trainer_module.validate_quantization_for_training, original)
 
 
+class TestVllmLoraSyncPatch(unittest.TestCase):
+    """The ``_generate_single_turn`` patch wires sync_weights to the right place.
+
+    These tests exercise the patch-installation branch in isolation. They build
+    a stub trainer with just enough attributes to look like
+    ``AsyncGRPOTrainer`` for the duration of the relevant code path.
+
+    Background: there are two correct behaviors, and we historically had a bug
+    where both modes used the same one:
+
+    - Async prefetch ON: the BG generation thread can't safely call
+      sync_weights mid-rollout. We no-op the stock hook and drive sync from
+      the main thread via ``_maybe_sync_vllm_weights``.
+    - Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
+      calls ``sync_weights`` once per step boundary on the main thread. We
+      wire that hook directly to ``_sync_lora_adapter`` because
+      ``_maybe_sync_vllm_weights`` is not reached on non-data-producer paths.
+
+    Before the fix, both modes installed ``lambda: None``, so sync mode never
+    pushed any LoRA adapter to vLLM and the trainer was a no-op.
+    """
+
+    @staticmethod
+    def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
+        from axolotl.core.trainers.grpo.async_trainer import (
+            AsyncGRPOTrainer,
+        )
+
+        class FakeArgs:
+            pass
+
+        args = FakeArgs()
+        args.vllm_lora_sync = vllm_lora_sync
+        args.async_prefetch = async_prefetch
+
+        class FakeVllmGen:
+            sync_weights = staticmethod(lambda: None)
+            model = MagicMock()
+
+        # Use object.__new__ so we don't run __init__ (which needs a real
+        # model, dataset, etc.). We only need the `_generate_single_turn`
+        # method's patch branch to run, so we set up the minimum state.
+        trainer = object.__new__(AsyncGRPOTrainer)
+        trainer.args = args
+        trainer.use_vllm = True
+        trainer.vllm_generation = FakeVllmGen()
+        trainer._patched_sync_weights = False
+        # Spy on _sync_lora_adapter so we can assert it's the function the
+        # hook delegates to in sync mode.
+        trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
+        trainer._sync_peft_weights_no_merge = MagicMock(
+            name="_sync_peft_weights_no_merge_spy"
+        )
+        return trainer
+
+    @staticmethod
+    def _run_patch_branch(trainer):
+        """Execute just the sync_weights-patching branch in isolation.
+
+        We can't easily call the real ``_generate_single_turn`` because it
+        does a full vLLM generate. Instead we copy the exact branch out of
+        the source so the test verifies the same logic the trainer runs.
+ """ + if not getattr(trainer, "_patched_sync_weights", False): + if trainer.use_vllm and hasattr(trainer, "vllm_generation"): + if getattr(trainer.args, "vllm_lora_sync", False): + if getattr(trainer.args, "async_prefetch", False): + trainer.vllm_generation.sync_weights = lambda: None + else: + sync_helper = trainer._sync_lora_adapter + + def _lora_filesystem_sync(): + sync_helper() + + trainer.vllm_generation.sync_weights = _lora_filesystem_sync + trainer._patched_sync_weights = True + + def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + + assert trainer._patched_sync_weights is True + # Trigger the patched hook — it must call _sync_lora_adapter. + trainer.vllm_generation.sync_weights() + trainer._sync_lora_adapter.assert_called_once() + + def test_async_mode_with_lora_sync_installs_noop_hook(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True) + self._run_patch_branch(trainer) + + assert trainer._patched_sync_weights is True + # Hook must be a no-op so BG-thread generation doesn't fight the + # main-thread optimizer step over the model weights. + trainer.vllm_generation.sync_weights() + trainer._sync_lora_adapter.assert_not_called() + + def test_sync_mode_with_lora_sync_does_not_call_during_install(self): + """Installing the patch should not pre-emptively sync.""" + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + # _sync_lora_adapter should only be called when the patched hook + # itself is invoked (e.g., from TRL's _generate_single_turn). + trainer._sync_lora_adapter.assert_not_called() + + def test_patch_is_idempotent(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + first_hook = trainer.vllm_generation.sync_weights + # Second call must not re-patch (otherwise we'd lose the original). + self._run_patch_branch(trainer) + assert trainer.vllm_generation.sync_weights is first_hook + + +class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase): + """``_maybe_sync_vllm_weights`` must not crash when interval is unset. + + Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError + on the very first call when ``vllm_sync_interval`` was ``None`` (which + is the default for any config that doesn't explicitly set it). We now + fall back to interval=1 so unset means "sync every step", matching the + behavior of TRL's own ``_generate_single_turn``. 
+ """ + + @staticmethod + def _make_stub_trainer(interval, async_prefetch): + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + class FakeArgs: + pass + + args = FakeArgs() + args.async_prefetch = async_prefetch + args.vllm_sync_interval = interval + args.vllm_lora_sync = True + + class FakeState: + global_step = 1 + + trainer = object.__new__(AsyncGRPOTrainer) + trainer.args = args + trainer.use_vllm = True + trainer.state = FakeState() + trainer._last_synced_step = 0 + trainer._sync_lora_adapter = MagicMock(name="sync_spy") + return trainer + + def test_interval_none_in_async_mode_does_not_crash(self): + trainer = self._make_stub_trainer(interval=None, async_prefetch=True) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + # Should not raise TypeError — defaults to every-step sync + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + def test_sync_mode_drives_sync(self): + """Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``. + + The previous behavior (early return when ``not async_prefetch``) + assumed TRL's stock ``_generate_single_turn`` would handle sync. + That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn + where the data producer bypasses ``_generate_single_turn`` + entirely. Without this trigger no sync ever happens and the + trainer becomes a no-op. + """ + trainer = self._make_stub_trainer(interval=1, async_prefetch=False) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + def test_async_mode_with_explicit_interval_respects_modulo(self): + trainer = self._make_stub_trainer(interval=4, async_prefetch=True) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + # global_step=1, interval=4 → 1 % 4 != 0 → no sync + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_not_called() + + # global_step=4 → 4 % 4 == 0 → sync + trainer.state.global_step = 4 + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + if __name__ == "__main__": unittest.main()