dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache (#3549)
* dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache * address PR comments, lint
This commit is contained in:
@@ -1536,6 +1536,29 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
) -> None:
|
||||
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
|
||||
|
||||
def _notify_rollouts_scored(
|
||||
self,
|
||||
prompts: list[str],
|
||||
completions: list[str],
|
||||
rewards: dict[str, list[float]],
|
||||
advantages: list[float],
|
||||
):
|
||||
"""Dispatch on_rollouts_scored to all registered plugins (rank 0 only)."""
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
pm = PluginManager.get_instance()
|
||||
if pm and pm.plugins:
|
||||
# Try _axolotl_cfg first (set by causal builder), fall back to
|
||||
# PluginManager's stored cfg (set during register phase).
|
||||
cfg = getattr(self, "_axolotl_cfg", None) or getattr(pm, "_cfg", None)
|
||||
if cfg is not None:
|
||||
pm.on_rollouts_scored(
|
||||
cfg, self, prompts, completions, rewards, advantages
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main-thread scoring
|
||||
# ------------------------------------------------------------------
|
||||
@@ -1860,7 +1883,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
nanmax(self.accelerator.gather(torch.max(flat_isr))).item()
|
||||
)
|
||||
|
||||
# Log prompt/completion texts
|
||||
# Log prompt/completion texts.
|
||||
# NB: gather_object merges per-rank local texts into a full-batch list
|
||||
# matching rewards_per_func and all_advantages which are already full-batch
|
||||
# tensors (gathered/computed earlier in this method). Lengths stay aligned.
|
||||
prompts_text = self.processing_class.batch_decode(
|
||||
prompt_ids, skip_special_tokens=True
|
||||
)
|
||||
@@ -1868,11 +1894,25 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
completion_ids, skip_special_tokens=True
|
||||
)
|
||||
if gather_object is not None:
|
||||
self._logs["prompt"].extend(gather_object(prompts_text))
|
||||
self._logs["completion"].extend(gather_object(completions_text))
|
||||
gathered_prompts = gather_object(prompts_text)
|
||||
gathered_completions = gather_object(completions_text)
|
||||
self._logs["prompt"].extend(gathered_prompts)
|
||||
self._logs["completion"].extend(gathered_completions)
|
||||
else:
|
||||
gathered_prompts = prompts_text
|
||||
gathered_completions = completions_text
|
||||
rewards_dict = {}
|
||||
for i, name in enumerate(self.reward_func_names):
|
||||
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
|
||||
self._logs["advantages"].extend(all_advantages.tolist())
|
||||
reward_list = rewards_per_func[:, i].tolist() # already full-batch
|
||||
self._logs["rewards"][name].extend(reward_list)
|
||||
rewards_dict[name] = reward_list
|
||||
adv_list = all_advantages.tolist() # already full-batch
|
||||
self._logs["advantages"].extend(adv_list)
|
||||
|
||||
# Notify plugins of scored rollouts
|
||||
self._notify_rollouts_scored(
|
||||
gathered_prompts, gathered_completions, rewards_dict, adv_list
|
||||
)
|
||||
|
||||
# Remove deferred keys
|
||||
for k in list(data.keys()):
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||
|
||||
@@ -242,6 +242,30 @@ class BasePlugin:
|
||||
"""
|
||||
return []
|
||||
|
||||
def on_rollouts_scored(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
trainer,
|
||||
prompts: list[str],
|
||||
completions: list[str],
|
||||
rewards: dict[str, list[float]],
|
||||
advantages: list[float],
|
||||
):
|
||||
"""Called after rollouts are scored during online RL (GRPO/PPO).
|
||||
|
||||
Provides access to the full scored rollout data for logging, trace
|
||||
storage, or analysis. Called once per scoring step with all samples
|
||||
from that step.
|
||||
|
||||
Args:
|
||||
cfg: The axolotl configuration.
|
||||
trainer: The trainer instance.
|
||||
prompts: List of prompt texts (one per sample).
|
||||
completions: List of completion texts (one per sample).
|
||||
rewards: Dict mapping reward function name to list of reward values.
|
||||
advantages: List of advantage values (one per sample).
|
||||
"""
|
||||
|
||||
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Performs actions after training is complete.
|
||||
|
||||
@@ -613,6 +637,36 @@ class PluginManager:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train(cfg, model)
|
||||
|
||||
def on_rollouts_scored(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
trainer,
|
||||
prompts: list[str],
|
||||
completions: list[str],
|
||||
rewards: dict[str, list[float]],
|
||||
advantages: list[float],
|
||||
):
|
||||
"""Calls the on_rollouts_scored method of all registered plugins.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugins.
|
||||
trainer: The trainer instance.
|
||||
prompts: List of prompt texts.
|
||||
completions: List of completion texts.
|
||||
rewards: Dict mapping reward function name to list of rewards.
|
||||
advantages: List of advantage values.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
try:
|
||||
plugin.on_rollouts_scored(
|
||||
cfg, trainer, prompts, completions, rewards, advantages
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
f"Plugin {plugin.__class__.__name__}.on_rollouts_scored failed",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def post_train_unload(self, cfg: DictDefault):
|
||||
"""Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
|
||||
@@ -114,8 +114,14 @@ def llm_worker(
|
||||
load_inplace=lr.get("load_inplace", False),
|
||||
)
|
||||
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
try:
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning("Worker method %s failed: %s", method_name, exc)
|
||||
if command["type"] == "call":
|
||||
connection.send({"error": str(exc), "kind": "worker_error"})
|
||||
continue
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
@@ -650,13 +656,13 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
@app.post("/reset_prefix_cache/")
|
||||
async def reset_prefix_cache():
|
||||
# Fire-and-forget: send reset without expecting a reply.
|
||||
# Using "fire_and_forget" type so workers don't send back a response
|
||||
# that would sit in the pipe and corrupt the next recv() for
|
||||
# generate/chat calls.
|
||||
for conn in connections:
|
||||
conn.send({"type": "call", "method": "reset_prefix_cache"})
|
||||
loop = asyncio.get_running_loop()
|
||||
results = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, conn.recv) for conn in connections)
|
||||
)
|
||||
return {"message": f"Reset prefix cache: {all(results)}"}
|
||||
conn.send({"type": "fire_and_forget", "method": "reset_prefix_cache"})
|
||||
return {"message": "Reset prefix cache received"}
|
||||
|
||||
@app.post("/close_communicator/")
|
||||
async def close_communicator():
|
||||
|
||||
Reference in New Issue
Block a user