make gpus go brrr
This commit is contained in:
@@ -25,7 +25,7 @@ import concurrent.futures
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
@@ -562,7 +562,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self._prompt_iter = None
|
||||
self._last_synced_step = -1
|
||||
self._buffered_inputs: list | None = None # override stock attr
|
||||
self._current_train_step_time = 0.0
|
||||
|
||||
# Data producer (the proper architecture for async generation)
|
||||
self.data_producer = None
|
||||
@@ -869,6 +868,24 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
|
||||
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||
"""Launch reward computation in background. Override for parallel dispatch.
|
||||
|
||||
Default: no-op (rewards computed synchronously in _collect_reward_workers).
|
||||
"""
|
||||
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||
"""Collect reward results. Override to collect from parallel workers.
|
||||
|
||||
Default: compute rewards synchronously now.
|
||||
"""
|
||||
args = getattr(self, "_pending_reward_args", None)
|
||||
if args is not None:
|
||||
self._pending_reward_args = None
|
||||
return self._compute_rewards_for_batch(*args)
|
||||
return self._compute_rewards_for_batch(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _post_advantage_hook(
|
||||
self,
|
||||
data: dict,
|
||||
@@ -929,6 +946,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
forward_kwargs[key] = data[key]
|
||||
num_images = data.get("num_images")
|
||||
|
||||
# --- Launch rewards in parallel with logprobs ---
|
||||
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
# --- Policy logprobs ---
|
||||
logprob_batch_size = min(batch_size * 4, len(prompt_ids))
|
||||
with disable_gradient_checkpointing(
|
||||
@@ -1013,8 +1033,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
||||
data["importance_sampling_ratio"] = is_ratio
|
||||
|
||||
# --- Rewards ---
|
||||
rewards_per_func = self._compute_rewards_for_batch(
|
||||
# --- Collect rewards (launched before logprobs, should be done) ---
|
||||
rewards_per_func = self._collect_reward_workers(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
|
||||
@@ -1267,6 +1287,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
):
|
||||
num_images = num_images[s_start:s_end]
|
||||
|
||||
# --- Launch rewards in parallel with logprobs ---
|
||||
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
# --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
|
||||
logprob_batch_size = min(batch_size * 4, chunk_size)
|
||||
with disable_gradient_checkpointing(
|
||||
self.model, self.args.gradient_checkpointing_kwargs
|
||||
@@ -1356,8 +1380,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
)
|
||||
data["ref_per_token_logps"][s_start:s_end] = ref_logps
|
||||
|
||||
# --- Rewards ---
|
||||
rewards_per_func = self._compute_rewards_for_batch(
|
||||
# --- Collect rewards (should already be done, ran in parallel with logprobs) ---
|
||||
rewards_per_func = self._collect_reward_workers(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
|
||||
@@ -1620,6 +1644,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
# Produce a new rollout
|
||||
self._maybe_sync_vllm_weights()
|
||||
|
||||
rollout_dataset = self.data_producer.produce(
|
||||
self.model,
|
||||
self.state.global_step,
|
||||
@@ -1728,7 +1753,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
token_type_ids=inputs.get("token_type_ids"),
|
||||
mm_token_type_ids=inputs.get("mm_token_type_ids"),
|
||||
)
|
||||
|
||||
if self.top_entropy_quantile < 1.0:
|
||||
entropy_mask = self.get_high_entropy_mask(
|
||||
entropies, mask, 1 - self.top_entropy_quantile
|
||||
@@ -1910,12 +1934,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def training_step(self, model, inputs, num_items_in_batch=None):
|
||||
t0 = time.perf_counter()
|
||||
output = super().training_step(model, inputs, num_items_in_batch)
|
||||
self._step += 1
|
||||
t1 = time.perf_counter()
|
||||
self._current_train_step_time += t1 - t0
|
||||
if self._step % self.current_gradient_accumulation_steps == 0:
|
||||
self._metrics["train"]["step_time"].append(self._current_train_step_time)
|
||||
self._current_train_step_time = 0.0
|
||||
return output
|
||||
|
||||
@@ -337,9 +337,16 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
def _compute_rewards_for_batch(
|
||||
self, inputs, prompts, completions, completion_ids_list
|
||||
):
|
||||
"""Dispatch rewards to parallel subprocess workers when possible."""
|
||||
from accelerate.utils import gather
|
||||
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
|
||||
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
return self._collect_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||
"""Send reward work to subprocess workers (non-blocking).
|
||||
|
||||
Results are collected later by _collect_reward_workers, allowing GPU
|
||||
logprob computation to overlap with CPU reward computation.
|
||||
"""
|
||||
reward_can_bg = all(
|
||||
not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
|
||||
for rf in self.reward_funcs
|
||||
@@ -347,12 +354,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
num_workers = getattr(self.args, "reward_num_workers", 1)
|
||||
|
||||
if not reward_can_bg or num_workers <= 1:
|
||||
return self._calculate_rewards(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
# Can't parallelize — store args for sync fallback in collect
|
||||
self._reward_workers_used = None
|
||||
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||
return
|
||||
|
||||
workers = self._get_reward_workers()
|
||||
device = self.accelerator.device
|
||||
num_generations = self.num_generations
|
||||
num_prompts = len(prompts)
|
||||
num_groups = num_prompts // num_generations
|
||||
@@ -379,7 +386,28 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
)
|
||||
workers_used.append(conn)
|
||||
|
||||
# Collect results
|
||||
self._reward_workers_used = workers_used
|
||||
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||
"""Collect reward results from subprocess workers (blocks until done)."""
|
||||
from accelerate.utils import gather
|
||||
|
||||
workers_used = getattr(self, "_reward_workers_used", None)
|
||||
args = getattr(self, "_pending_reward_args", None)
|
||||
self._reward_workers_used = None
|
||||
self._pending_reward_args = None
|
||||
|
||||
if workers_used is None:
|
||||
# Sync fallback — compute on main thread
|
||||
if args is not None:
|
||||
return self._calculate_rewards(*args)
|
||||
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
device = self.accelerator.device
|
||||
num_prompts = len(args[1]) if args else len(prompts)
|
||||
|
||||
# Collect results from workers
|
||||
all_worker_results = []
|
||||
any_failed = False
|
||||
for conn in workers_used:
|
||||
@@ -404,9 +432,9 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
return gather(rewards_per_func)
|
||||
|
||||
# Fallback to main thread on failure
|
||||
return self._calculate_rewards(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
if args is not None:
|
||||
return self._calculate_rewards(*args)
|
||||
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _post_advantage_hook(
|
||||
self,
|
||||
@@ -451,9 +479,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
||||
|
||||
# Replace zero-signal groups with high-signal replay buffer entries
|
||||
if True:
|
||||
# Only in non-streaming path (s_start is None) — streaming scores
|
||||
# groups incrementally, so replacement + logprob recompute would be
|
||||
# too expensive per chunk.
|
||||
n_replaced = 0
|
||||
if s_start is None:
|
||||
no_signal = ~has_signal
|
||||
n_replaced = 0
|
||||
replaced_ranges = []
|
||||
if no_signal.any() and len(self._replay_buffer) > 0:
|
||||
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
|
||||
|
||||
Reference in New Issue
Block a user