dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache (#3549)

* dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache

* address PR comments, lint
This commit is contained in:
Wing Lian
2026-03-25 11:19:15 -04:00
committed by GitHub
parent b55706b9f6
commit 74b959e035
4 changed files with 116 additions and 13 deletions

View File

@@ -1536,6 +1536,29 @@ class AsyncGRPOTrainer(GRPOTrainer):
) -> None:
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
def _notify_rollouts_scored(
self,
prompts: list[str],
completions: list[str],
rewards: dict[str, list[float]],
advantages: list[float],
):
"""Dispatch on_rollouts_scored to all registered plugins (rank 0 only)."""
if not self.accelerator.is_main_process:
return
from axolotl.integrations.base import PluginManager
pm = PluginManager.get_instance()
if pm and pm.plugins:
# Try _axolotl_cfg first (set by causal builder), fall back to
# PluginManager's stored cfg (set during register phase).
cfg = getattr(self, "_axolotl_cfg", None) or getattr(pm, "_cfg", None)
if cfg is not None:
pm.on_rollouts_scored(
cfg, self, prompts, completions, rewards, advantages
)
# ------------------------------------------------------------------
# Main-thread scoring
# ------------------------------------------------------------------
@@ -1860,7 +1883,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
nanmax(self.accelerator.gather(torch.max(flat_isr))).item()
)
# Log prompt/completion texts
# Log prompt/completion texts.
# NB: gather_object merges per-rank local texts into a full-batch list
# matching rewards_per_func and all_advantages which are already full-batch
# tensors (gathered/computed earlier in this method). Lengths stay aligned.
prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=True
)
@@ -1868,11 +1894,25 @@ class AsyncGRPOTrainer(GRPOTrainer):
completion_ids, skip_special_tokens=True
)
if gather_object is not None:
self._logs["prompt"].extend(gather_object(prompts_text))
self._logs["completion"].extend(gather_object(completions_text))
gathered_prompts = gather_object(prompts_text)
gathered_completions = gather_object(completions_text)
self._logs["prompt"].extend(gathered_prompts)
self._logs["completion"].extend(gathered_completions)
else:
gathered_prompts = prompts_text
gathered_completions = completions_text
rewards_dict = {}
for i, name in enumerate(self.reward_func_names):
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_advantages.tolist())
reward_list = rewards_per_func[:, i].tolist() # already full-batch
self._logs["rewards"][name].extend(reward_list)
rewards_dict[name] = reward_list
adv_list = all_advantages.tolist() # already full-batch
self._logs["advantages"].extend(adv_list)
# Notify plugins of scored rollouts
self._notify_rollouts_scored(
gathered_prompts, gathered_completions, rewards_dict, adv_list
)
# Remove deferred keys
for k in list(data.keys()):

View File

@@ -0,0 +1,3 @@
import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__)

View File

@@ -242,6 +242,30 @@ class BasePlugin:
"""
return []
def on_rollouts_scored(
self,
cfg: DictDefault,
trainer,
prompts: list[str],
completions: list[str],
rewards: dict[str, list[float]],
advantages: list[float],
):
"""Called after rollouts are scored during online RL (GRPO/PPO).
Provides access to the full scored rollout data for logging, trace
storage, or analysis. Called once per scoring step with all samples
from that step.
Args:
cfg: The axolotl configuration.
trainer: The trainer instance.
prompts: List of prompt texts (one per sample).
completions: List of completion texts (one per sample).
rewards: Dict mapping reward function name to list of reward values.
advantages: List of advantage values (one per sample).
"""
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
@@ -613,6 +637,36 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def on_rollouts_scored(
self,
cfg: DictDefault,
trainer,
prompts: list[str],
completions: list[str],
rewards: dict[str, list[float]],
advantages: list[float],
):
"""Calls the on_rollouts_scored method of all registered plugins.
Args:
cfg: The configuration for the plugins.
trainer: The trainer instance.
prompts: List of prompt texts.
completions: List of completion texts.
rewards: Dict mapping reward function name to list of rewards.
advantages: List of advantage values.
"""
for plugin in self.plugins.values():
try:
plugin.on_rollouts_scored(
cfg, trainer, prompts, completions, rewards, advantages
)
except Exception:
LOG.warning(
f"Plugin {plugin.__class__.__name__}.on_rollouts_scored failed",
exc_info=True,
)
def post_train_unload(self, cfg: DictDefault):
"""Calls the post_train_unload method of all registered plugins.

View File

@@ -114,8 +114,14 @@ def llm_worker(
load_inplace=lr.get("load_inplace", False),
)
method = getattr(llm, method_name)
result = method(*args, **kwargs)
try:
method = getattr(llm, method_name)
result = method(*args, **kwargs)
except Exception as exc:
logger.warning("Worker method %s failed: %s", method_name, exc)
if command["type"] == "call":
connection.send({"error": str(exc), "kind": "worker_error"})
continue
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
@@ -650,13 +656,13 @@ def main(script_args: ScriptArguments):
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
# Fire-and-forget: send reset without expecting a reply.
# Using "fire_and_forget" type so workers don't send back a response
# that would sit in the pipe and corrupt the next recv() for
# generate/chat calls.
for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"})
loop = asyncio.get_running_loop()
results = await asyncio.gather(
*(loop.run_in_executor(None, conn.recv) for conn in connections)
)
return {"message": f"Reset prefix cache: {all(results)}"}
conn.send({"type": "fire_and_forget", "method": "reset_prefix_cache"})
return {"message": "Reset prefix cache received"}
@app.post("/close_communicator/")
async def close_communicator():