dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache (#3549)
* dispatch scored rollouts to plugins, extend path for external plugins, better handle errors with vllm /reset_prefix_cache
* address PR comments, lint
This commit is contained in:
@@ -1536,6 +1536,29 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
|
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
|
||||||
|
|
||||||
|
def _notify_rollouts_scored(
    self,
    prompts: list[str],
    completions: list[str],
    rewards: dict[str, list[float]],
    advantages: list[float],
) -> None:
    """Dispatch on_rollouts_scored to all registered plugins (rank 0 only).

    Does nothing when this process is not the accelerator's main process,
    when no plugins are registered, or when no config can be found.

    Args:
        prompts: Prompt texts passed through to the plugins.
        completions: Completion texts passed through to the plugins.
        rewards: Mapping of reward function name to reward values.
        advantages: Advantage values passed through to the plugins.
    """
    if not self.accelerator.is_main_process:
        return

    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle -- confirm.
    from axolotl.integrations.base import PluginManager

    pm = PluginManager.get_instance()
    if not (pm and pm.plugins):
        return

    # Try _axolotl_cfg first (set by causal builder), fall back to
    # PluginManager's stored cfg (set during register phase).
    cfg = getattr(self, "_axolotl_cfg", None) or getattr(pm, "_cfg", None)
    if cfg is None:
        return

    pm.on_rollouts_scored(cfg, self, prompts, completions, rewards, advantages)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Main-thread scoring
|
# Main-thread scoring
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -1860,7 +1883,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
nanmax(self.accelerator.gather(torch.max(flat_isr))).item()
|
nanmax(self.accelerator.gather(torch.max(flat_isr))).item()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log prompt/completion texts
|
# Log prompt/completion texts.
|
||||||
|
# NB: gather_object merges per-rank local texts into a full-batch list
|
||||||
|
# matching rewards_per_func and all_advantages which are already full-batch
|
||||||
|
# tensors (gathered/computed earlier in this method). Lengths stay aligned.
|
||||||
prompts_text = self.processing_class.batch_decode(
|
prompts_text = self.processing_class.batch_decode(
|
||||||
prompt_ids, skip_special_tokens=True
|
prompt_ids, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
@@ -1868,11 +1894,25 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
completion_ids, skip_special_tokens=True
|
completion_ids, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
if gather_object is not None:
|
if gather_object is not None:
|
||||||
self._logs["prompt"].extend(gather_object(prompts_text))
|
gathered_prompts = gather_object(prompts_text)
|
||||||
self._logs["completion"].extend(gather_object(completions_text))
|
gathered_completions = gather_object(completions_text)
|
||||||
|
self._logs["prompt"].extend(gathered_prompts)
|
||||||
|
self._logs["completion"].extend(gathered_completions)
|
||||||
|
else:
|
||||||
|
gathered_prompts = prompts_text
|
||||||
|
gathered_completions = completions_text
|
||||||
|
rewards_dict = {}
|
||||||
for i, name in enumerate(self.reward_func_names):
|
for i, name in enumerate(self.reward_func_names):
|
||||||
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
|
reward_list = rewards_per_func[:, i].tolist() # already full-batch
|
||||||
self._logs["advantages"].extend(all_advantages.tolist())
|
self._logs["rewards"][name].extend(reward_list)
|
||||||
|
rewards_dict[name] = reward_list
|
||||||
|
adv_list = all_advantages.tolist() # already full-batch
|
||||||
|
self._logs["advantages"].extend(adv_list)
|
||||||
|
|
||||||
|
# Notify plugins of scored rollouts
|
||||||
|
self._notify_rollouts_scored(
|
||||||
|
gathered_prompts, gathered_completions, rewards_dict, adv_list
|
||||||
|
)
|
||||||
|
|
||||||
# Remove deferred keys
|
# Remove deferred keys
|
||||||
for k in list(data.keys()):
|
for k in list(data.keys()):
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
import pkgutil

# Extend this package's search path with same-named package directories
# found elsewhere on sys.path, so portions of the package installed in
# other locations are importable under this package name.
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|||||||
@@ -242,6 +242,30 @@ class BasePlugin:
|
|||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def on_rollouts_scored(
    self,
    cfg: DictDefault,
    trainer,
    prompts: list[str],
    completions: list[str],
    rewards: dict[str, list[float]],
    advantages: list[float],
) -> None:
    """Called after rollouts are scored during online RL (GRPO/PPO).

    Provides access to the full scored rollout data for logging, trace
    storage, or analysis. Called once per scoring step with all samples
    from that step.

    This base implementation does nothing; plugins override it.

    Args:
        cfg: The axolotl configuration.
        trainer: The trainer instance.
        prompts: List of prompt texts (one per sample).
        completions: List of completion texts (one per sample).
        rewards: Dict mapping reward function name to list of reward values.
        advantages: List of advantage values (one per sample).
    """
|
||||||
|
|
||||||
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||||
"""Performs actions after training is complete.
|
"""Performs actions after training is complete.
|
||||||
|
|
||||||
@@ -613,6 +637,36 @@ class PluginManager:
|
|||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.post_train(cfg, model)
|
plugin.post_train(cfg, model)
|
||||||
|
|
||||||
|
def on_rollouts_scored(
    self,
    cfg: DictDefault,
    trainer,
    prompts: list[str],
    completions: list[str],
    rewards: dict[str, list[float]],
    advantages: list[float],
) -> None:
    """Calls the on_rollouts_scored method of all registered plugins.

    A plugin whose hook raises is logged as a warning (with traceback) and
    skipped; the remaining plugins are still called and the exception is
    not propagated to the caller.

    Args:
        cfg: The configuration for the plugins.
        trainer: The trainer instance.
        prompts: List of prompt texts.
        completions: List of completion texts.
        rewards: Dict mapping reward function name to list of rewards.
        advantages: List of advantage values.
    """
    for plugin in self.plugins.values():
        try:
            plugin.on_rollouts_scored(
                cfg, trainer, prompts, completions, rewards, advantages
            )
        except Exception:
            # Deliberately best-effort: one failing plugin must not stop
            # the others from being notified.
            LOG.warning(
                f"Plugin {plugin.__class__.__name__}.on_rollouts_scored failed",
                exc_info=True,
            )
|
||||||
|
|
||||||
def post_train_unload(self, cfg: DictDefault):
|
def post_train_unload(self, cfg: DictDefault):
|
||||||
"""Calls the post_train_unload method of all registered plugins.
|
"""Calls the post_train_unload method of all registered plugins.
|
||||||
|
|
||||||
|
|||||||
@@ -114,8 +114,14 @@ def llm_worker(
|
|||||||
load_inplace=lr.get("load_inplace", False),
|
load_inplace=lr.get("load_inplace", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
method = getattr(llm, method_name)
|
try:
|
||||||
result = method(*args, **kwargs)
|
method = getattr(llm, method_name)
|
||||||
|
result = method(*args, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Worker method %s failed: %s", method_name, exc)
|
||||||
|
if command["type"] == "call":
|
||||||
|
connection.send({"error": str(exc), "kind": "worker_error"})
|
||||||
|
continue
|
||||||
if command["type"] == "call":
|
if command["type"] == "call":
|
||||||
connection.send(result)
|
connection.send(result)
|
||||||
elif command["type"] == "shutdown":
|
elif command["type"] == "shutdown":
|
||||||
@@ -650,13 +656,13 @@ def main(script_args: ScriptArguments):
|
|||||||
|
|
||||||
@app.post("/reset_prefix_cache/")
|
@app.post("/reset_prefix_cache/")
|
||||||
async def reset_prefix_cache():
|
async def reset_prefix_cache():
|
||||||
|
# Fire-and-forget: send reset without expecting a reply.
|
||||||
|
# Using "fire_and_forget" type so workers don't send back a response
|
||||||
|
# that would sit in the pipe and corrupt the next recv() for
|
||||||
|
# generate/chat calls.
|
||||||
for conn in connections:
|
for conn in connections:
|
||||||
conn.send({"type": "call", "method": "reset_prefix_cache"})
|
conn.send({"type": "fire_and_forget", "method": "reset_prefix_cache"})
|
||||||
loop = asyncio.get_running_loop()
|
return {"message": "Reset prefix cache received"}
|
||||||
results = await asyncio.gather(
|
|
||||||
*(loop.run_in_executor(None, conn.recv) for conn in connections)
|
|
||||||
)
|
|
||||||
return {"message": f"Reset prefix cache: {all(results)}"}
|
|
||||||
|
|
||||||
@app.post("/close_communicator/")
|
@app.post("/close_communicator/")
|
||||||
async def close_communicator():
|
async def close_communicator():
|
||||||
|
|||||||
Reference in New Issue
Block a user