From 74b959e035df1e77e02ddb4b3990292bdf583cec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 25 Mar 2026 11:19:15 -0400 Subject: [PATCH] dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache (#3549) * dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache * address PR comments, lint --- .../core/trainers/grpo/async_trainer.py | 50 +++++++++++++++-- src/axolotl/integrations/__init__.py | 3 ++ src/axolotl/integrations/base.py | 54 +++++++++++++++++++ src/axolotl/scripts/vllm_serve_lora.py | 22 +++++--- 4 files changed, 116 insertions(+), 13 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index 9b6ae2e28..3e541c16d 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -1536,6 +1536,29 @@ class AsyncGRPOTrainer(GRPOTrainer): ) -> None: """Called after advantages are computed. Override for replay buffer, re-roll, etc.""" + def _notify_rollouts_scored( + self, + prompts: list[str], + completions: list[str], + rewards: dict[str, list[float]], + advantages: list[float], + ): + """Dispatch on_rollouts_scored to all registered plugins (rank 0 only).""" + if not self.accelerator.is_main_process: + return + + from axolotl.integrations.base import PluginManager + + pm = PluginManager.get_instance() + if pm and pm.plugins: + # Try _axolotl_cfg first (set by causal builder), fall back to + # PluginManager's stored cfg (set during register phase). + cfg = getattr(self, "_axolotl_cfg", None) or getattr(pm, "_cfg", None) + if cfg is not None: + pm.on_rollouts_scored( + cfg, self, prompts, completions, rewards, advantages + ) + # ------------------------------------------------------------------ # Main-thread scoring # ------------------------------------------------------------------ @@ -1860,7 +1883,10 @@ class AsyncGRPOTrainer(GRPOTrainer): nanmax(self.accelerator.gather(torch.max(flat_isr))).item() ) - # Log prompt/completion texts + # Log prompt/completion texts. + # NB: gather_object merges per-rank local texts into a full-batch list + # matching rewards_per_func and all_advantages which are already full-batch + # tensors (gathered/computed earlier in this method). Lengths stay aligned. prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=True ) @@ -1868,11 +1894,25 @@ class AsyncGRPOTrainer(GRPOTrainer): completion_ids, skip_special_tokens=True ) if gather_object is not None: - self._logs["prompt"].extend(gather_object(prompts_text)) - self._logs["completion"].extend(gather_object(completions_text)) + gathered_prompts = gather_object(prompts_text) + gathered_completions = gather_object(completions_text) + self._logs["prompt"].extend(gathered_prompts) + self._logs["completion"].extend(gathered_completions) + else: + gathered_prompts = prompts_text + gathered_completions = completions_text + rewards_dict = {} for i, name in enumerate(self.reward_func_names): - self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - self._logs["advantages"].extend(all_advantages.tolist()) + reward_list = rewards_per_func[:, i].tolist() # already full-batch + self._logs["rewards"][name].extend(reward_list) + rewards_dict[name] = reward_list + adv_list = all_advantages.tolist() # already full-batch + self._logs["advantages"].extend(adv_list) + + # Notify plugins of scored rollouts + self._notify_rollouts_scored( + gathered_prompts, gathered_completions, rewards_dict, adv_list + ) # Remove deferred keys for k in list(data.keys()): diff --git a/src/axolotl/integrations/__init__.py b/src/axolotl/integrations/__init__.py index e69de29bb..f77af49c2 100644 --- a/src/axolotl/integrations/__init__.py +++ b/src/axolotl/integrations/__init__.py @@ -0,0 +1,3 @@ +import pkgutil + +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index c66bc01c6..48260cb4f 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -242,6 +242,30 @@ class BasePlugin: """ return [] + def on_rollouts_scored( + self, + cfg: DictDefault, + trainer, + prompts: list[str], + completions: list[str], + rewards: dict[str, list[float]], + advantages: list[float], + ): + """Called after rollouts are scored during online RL (GRPO/PPO). + + Provides access to the full scored rollout data for logging, trace + storage, or analysis. Called once per scoring step with all samples + from that step. + + Args: + cfg: The axolotl configuration. + trainer: The trainer instance. + prompts: List of prompt texts (one per sample). + completions: List of completion texts (one per sample). + rewards: Dict mapping reward function name to list of reward values. + advantages: List of advantage values (one per sample). + """ + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Performs actions after training is complete. @@ -613,6 +637,36 @@ class PluginManager: for plugin in self.plugins.values(): plugin.post_train(cfg, model) + def on_rollouts_scored( + self, + cfg: DictDefault, + trainer, + prompts: list[str], + completions: list[str], + rewards: dict[str, list[float]], + advantages: list[float], + ): + """Calls the on_rollouts_scored method of all registered plugins. + + Args: + cfg: The configuration for the plugins. + trainer: The trainer instance. + prompts: List of prompt texts. + completions: List of completion texts. + rewards: Dict mapping reward function name to list of rewards. + advantages: List of advantage values. + """ + for plugin in self.plugins.values(): + try: + plugin.on_rollouts_scored( + cfg, trainer, prompts, completions, rewards, advantages + ) + except Exception: + LOG.warning( + f"Plugin {plugin.__class__.__name__}.on_rollouts_scored failed", + exc_info=True, + ) + def post_train_unload(self, cfg: DictDefault): """Calls the post_train_unload method of all registered plugins. diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index e292d89f8..2dda0f9bf 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -114,8 +114,14 @@ def llm_worker( load_inplace=lr.get("load_inplace", False), ) - method = getattr(llm, method_name) - result = method(*args, **kwargs) + try: + method = getattr(llm, method_name) + result = method(*args, **kwargs) + except Exception as exc: + logger.warning("Worker method %s failed: %s", method_name, exc) + if command["type"] == "call": + connection.send({"error": str(exc), "kind": "worker_error"}) + continue if command["type"] == "call": connection.send(result) elif command["type"] == "shutdown": @@ -650,13 +656,13 @@ def main(script_args: ScriptArguments): @app.post("/reset_prefix_cache/") async def reset_prefix_cache(): + # Fire-and-forget: send reset without expecting a reply. + # Using "fire_and_forget" type so workers don't send back a response + # that would sit in the pipe and corrupt the next recv() for + # generate/chat calls. for conn in connections: - conn.send({"type": "call", "method": "reset_prefix_cache"}) - loop = asyncio.get_running_loop() - results = await asyncio.gather( - *(loop.run_in_executor(None, conn.recv) for conn in connections) - ) - return {"message": f"Reset prefix cache: {all(results)}"} + conn.send({"type": "fire_and_forget", "method": "reset_prefix_cache"}) + return {"message": "Reset prefix cache received"} @app.post("/close_communicator/") async def close_communicator():