make gpus go brrr

Wing Lian
2026-03-10 03:29:10 +00:00
parent bba1330e9b
commit c887057e5e
2 changed files with 74 additions and 25 deletions


@@ -25,7 +25,7 @@ import concurrent.futures
import logging
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from contextlib import nullcontext
@@ -562,7 +562,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._prompt_iter = None
self._last_synced_step = -1
self._buffered_inputs: list | None = None # override stock attr
-        self._current_train_step_time = 0.0
# Data producer (the proper architecture for async generation)
self.data_producer = None
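For orientation: the "data producer" noted above decouples rollout generation from the training loop, so generation can run ahead while the trainer consumes finished rollouts. A minimal sketch of that shape, where `generate_rollout` is a hypothetical stand-in for the vLLM-backed generation call rather than this trainer's actual API:

```python
import queue
import threading

def start_producer(generate_rollout, rollout_queue: queue.Queue, stop: threading.Event):
    """Fill rollout_queue in the background; the trainer consumes from it."""
    def _loop():
        while not stop.is_set():
            rollout = generate_rollout()   # slow: sample completions
            rollout_queue.put(rollout)     # blocks when the trainer falls behind
    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return thread

# Trainer side: rollout = rollout_queue.get() overlaps with ongoing generation.
```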
@@ -869,6 +868,24 @@ class AsyncGRPOTrainer(GRPOTrainer):
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
+    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Launch reward computation in background. Override for parallel dispatch.
+
+        Default: stash the args; rewards are computed synchronously in
+        _collect_reward_workers.
+        """
+        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+
+    def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Collect reward results. Override to collect from parallel workers.
+
+        Default: compute rewards synchronously now.
+        """
+        args = getattr(self, "_pending_reward_args", None)
+        if args is not None:
+            self._pending_reward_args = None
+            return self._compute_rewards_for_batch(*args)
+        return self._compute_rewards_for_batch(inputs, prompts, completions, completion_ids_list)
def _post_advantage_hook(
self,
data: dict,
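The base-class contract here is deliberately trivial: `_launch_reward_workers` stashes the arguments and `_collect_reward_workers` computes synchronously, so a subclass can override either side to make the pair truly asynchronous. A thread-based override might look like the sketch below; this is illustrative, not the shipped implementation (which uses subprocess workers, see the second file):

```python
from concurrent.futures import Future, ThreadPoolExecutor

class ThreadedRewardMixin:
    """Overlap CPU-bound reward computation with GPU logprob passes."""

    _reward_pool = ThreadPoolExecutor(max_workers=1)
    _reward_future: Future | None = None

    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
        # Non-blocking: kick off rewards and return immediately so the
        # caller can start the logprob forward passes on the GPU.
        self._reward_future = self._reward_pool.submit(
            self._calculate_rewards, inputs, prompts, completions, completion_ids_list
        )

    def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
        # Blocking: by now the logprob passes should have hidden most of
        # the reward latency, so this wait is short.
        future, self._reward_future = self._reward_future, None
        if future is None:  # launch was never called; compute inline
            return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
        return future.result()
```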
@@ -929,6 +946,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
forward_kwargs[key] = data[key]
num_images = data.get("num_images")
+        # --- Launch rewards in parallel with logprobs ---
+        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
# --- Policy logprobs ---
logprob_batch_size = min(batch_size * 4, len(prompt_ids))
with disable_gradient_checkpointing(
@@ -1013,8 +1033,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
data["importance_sampling_ratio"] = is_ratio
-        # --- Rewards ---
-        rewards_per_func = self._compute_rewards_for_batch(
+        # --- Collect rewards (launched before logprobs, should be done) ---
+        rewards_per_func = self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
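The ordering is the whole optimization: launch before the logprob passes, collect after, so the CPU reward functions run underneath the GPU forward passes. A toy measurement of the effect, with `time.sleep` standing in for both workloads:

```python
import threading
import time

def cpu_rewards():    # stand-in for rule-based reward functions
    time.sleep(0.5)

def gpu_logprobs():   # stand-in for the policy logprob forward passes
    time.sleep(0.5)

t0 = time.perf_counter()
gpu_logprobs(); cpu_rewards()                  # serial: ~1.0s
print(f"serial:  {time.perf_counter() - t0:.2f}s")

t0 = time.perf_counter()
worker = threading.Thread(target=cpu_rewards)
worker.start()                                 # _launch_reward_workers
gpu_logprobs()                                 # GPU work proceeds concurrently
worker.join()                                  # _collect_reward_workers
print(f"overlap: {time.perf_counter() - t0:.2f}s")
```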
@@ -1267,6 +1287,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
):
num_images = num_images[s_start:s_end]
+        # --- Launch rewards in parallel with logprobs ---
+        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
# --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
logprob_batch_size = min(batch_size * 4, chunk_size)
with disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs
@@ -1356,8 +1380,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
)
data["ref_per_token_logps"][s_start:s_end] = ref_logps
-        # --- Rewards ---
-        rewards_per_func = self._compute_rewards_for_batch(
+        # --- Collect rewards (should already be done, ran in parallel with logprobs) ---
+        rewards_per_func = self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
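The streaming path scores each chunk with micro-batched forward passes, capping the slice size at four times the train batch size (`logprob_batch_size = min(batch_size * 4, chunk_size)`). The slicing pattern in isolation, where `score_fn` is a placeholder for the per-token logprob forward:

```python
import torch

def batched_logprobs(score_fn, input_ids: torch.Tensor, micro_batch: int) -> torch.Tensor:
    # Forward in bounded slices to cap peak activation memory, then
    # reassemble; a larger micro-batch amortizes per-call overhead.
    parts = [score_fn(input_ids[i : i + micro_batch])
             for i in range(0, len(input_ids), micro_batch)]
    return torch.cat(parts, dim=0)
```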
@@ -1620,6 +1644,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Produce a new rollout
+        self._maybe_sync_vllm_weights()
rollout_dataset = self.data_producer.produce(
self.model,
self.state.global_step,
@@ -1728,7 +1753,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
token_type_ids=inputs.get("token_type_ids"),
mm_token_type_ids=inputs.get("mm_token_type_ids"),
)
if self.top_entropy_quantile < 1.0:
entropy_mask = self.get_high_entropy_mask(
entropies, mask, 1 - self.top_entropy_quantile
@@ -1910,12 +1934,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
    # ------------------------------------------------------------------
    def training_step(self, model, inputs, num_items_in_batch=None):
-        t0 = time.perf_counter()
        output = super().training_step(model, inputs, num_items_in_batch)
        self._step += 1
-        t1 = time.perf_counter()
-        self._current_train_step_time += t1 - t0
-        if self._step % self.current_gradient_accumulation_steps == 0:
-            self._metrics["train"]["step_time"].append(self._current_train_step_time)
-            self._current_train_step_time = 0.0
        return output
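The lines removed from `training_step` implemented per-optimizer-step timing: microbatch wall times were accumulated across gradient-accumulation steps and flushed to the metrics once per update. That pattern in isolation, as a standalone sketch rather than the trainer's code:

```python
import time

class StepTimer:
    """Accumulate microbatch times; emit one total per optimizer step."""

    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.micro_steps = 0
        self.elapsed = 0.0

    def timed(self, fn, *args):
        t0 = time.perf_counter()
        out = fn(*args)
        self.elapsed += time.perf_counter() - t0
        self.micro_steps += 1
        if self.micro_steps % self.accumulation_steps == 0:
            print(f"step_time: {self.elapsed:.3f}s")  # one log per optimizer step
            self.elapsed = 0.0
        return out
```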


@@ -337,9 +337,16 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list
):
"""Dispatch rewards to parallel subprocess workers when possible."""
from accelerate.utils import gather
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(inputs, prompts, completions, completion_ids_list)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all(
not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
@@ -347,12 +354,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1:
-            return self._calculate_rewards(
-                inputs, prompts, completions, completion_ids_list
-            )
+            # Can't parallelize — store args for sync fallback in collect
+            self._reward_workers_used = None
+            self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+            return
        workers = self._get_reward_workers()
-        device = self.accelerator.device
num_generations = self.num_generations
num_prompts = len(prompts)
num_groups = num_prompts // num_generations
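Dispatch happens at group granularity: each prompt's `num_generations` completions form one group, and whole groups are sent to subprocess workers (only when every reward function is a plain callable, per the `reward_can_bg` check above, since `nn.Module` and coroutine rewards cannot run in a CPU subprocess). A sketch of pipe-based group dispatch under those assumptions; the worker protocol here is hypothetical:

```python
import multiprocessing as mp

def _reward_worker(conn):
    # Child process: score whole groups of completions on the CPU.
    while True:
        msg = conn.recv()
        if msg is None:          # shutdown sentinel (shutdown elided below)
            break
        group_idx, group_prompts, group_completions = msg
        scores = [float(len(c)) for c in group_completions]  # placeholder reward
        conn.send((group_idx, scores))

def dispatch_groups(prompts, completions, num_generations, num_workers):
    num_groups = len(prompts) // num_generations
    ctx = mp.get_context("spawn")
    conns = []
    for _ in range(num_workers):
        parent, child = ctx.Pipe()
        ctx.Process(target=_reward_worker, args=(child,), daemon=True).start()
        conns.append(parent)
    for g in range(num_groups):  # round-robin whole groups over workers
        s = g * num_generations
        conns[g % num_workers].send(
            (g, prompts[s : s + num_generations], completions[s : s + num_generations])
        )
    return conns  # collect side recv()s one result per group
```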
@@ -379,7 +386,28 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
)
workers_used.append(conn)
-        # Collect results
+        self._reward_workers_used = workers_used
+        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+
+    def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Collect reward results from subprocess workers (blocks until done)."""
+        from accelerate.utils import gather
+
+        workers_used = getattr(self, "_reward_workers_used", None)
+        args = getattr(self, "_pending_reward_args", None)
+        self._reward_workers_used = None
+        self._pending_reward_args = None
+        if workers_used is None:
+            # Sync fallback — compute on main thread
+            if args is not None:
+                return self._calculate_rewards(*args)
+            return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
+        device = self.accelerator.device
+        num_prompts = len(args[1]) if args else len(prompts)
+        # Collect results from workers
all_worker_results = []
any_failed = False
for conn in workers_used:
@@ -404,9 +432,9 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
return gather(rewards_per_func)
# Fallback to main thread on failure
-        return self._calculate_rewards(
-            inputs, prompts, completions, completion_ids_list
-        )
+        if args is not None:
+            return self._calculate_rewards(*args)
+        return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
def _post_advantage_hook(
self,
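On the collect side each rank blocks on its own workers, assembles a local reward matrix, and then all-gathers so every process sees the full batch, matching what the synchronous `_calculate_rewards` path returns. The cross-process step in isolation, assuming an initialized accelerate distributed context:

```python
import torch
from accelerate.utils import gather

def finish_rewards(local_scores, device, num_reward_funcs):
    # local_scores: flat list of per-completion, per-function rewards on this rank
    rewards_per_func = torch.tensor(
        local_scores, dtype=torch.float32, device=device
    ).view(-1, num_reward_funcs)
    # gather() concatenates along dim 0 across processes, so every rank
    # ends up with the (global_completions, num_reward_funcs) matrix.
    return gather(rewards_per_func)
```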
@@ -451,9 +479,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries
-        if True:
+        # Only in non-streaming path (s_start is None) — streaming scores
+        # groups incrementally, so replacement + logprob recompute would be
+        # too expensive per chunk.
+        n_replaced = 0
+        if s_start is None:
            no_signal = ~has_signal
-            n_replaced = 0
replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
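The replacement targets zero-signal groups: when every completion in a group earns the same reward, the group-relative advantage is identically zero and the group contributes no gradient, so swapping in a stored high-variance group restores learning signal. A minimal score-ordered buffer in that spirit (a sketch, not the trainer's `_replay_buffer` class):

```python
import heapq
import itertools

class ReplayBuffer:
    """Keep the highest-signal rollout groups seen so far."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._heap = []                     # min-heap: weakest entry evicted first
        self._tie = itertools.count()       # tie-breaker so payloads never compare

    def add(self, score: float, group_data: dict):
        item = (score, next(self._tie), group_data)
        if len(self._heap) < self.capacity:
            heapq.heappush(self._heap, item)
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, item)

    def __len__(self):
        return len(self._heap)

    def best(self) -> dict:
        # Candidate to swap in for a zero-signal group.
        return max(self._heap)[2]
```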