make gpus go brrr

This commit is contained in:
Wing Lian
2026-03-10 03:29:10 +00:00
parent bba1330e9b
commit c887057e5e
2 changed files with 74 additions and 25 deletions

View File

@@ -25,7 +25,7 @@ import concurrent.futures
import logging import logging
import queue import queue
import threading import threading
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
@@ -562,7 +562,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._prompt_iter = None self._prompt_iter = None
self._last_synced_step = -1 self._last_synced_step = -1
self._buffered_inputs: list | None = None # override stock attr self._buffered_inputs: list | None = None # override stock attr
self._current_train_step_time = 0.0
# Data producer (the proper architecture for async generation) # Data producer (the proper architecture for async generation)
self.data_producer = None self.data_producer = None
@@ -869,6 +868,24 @@ class AsyncGRPOTrainer(GRPOTrainer):
"""Compute rewards for a batch. Override for parallel workers, caching, etc.""" """Compute rewards for a batch. Override for parallel workers, caching, etc."""
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list) return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
    """Stage a reward computation so it can be collected later.

    Subclasses may override this to dispatch the work to background
    workers. This default implementation starts nothing: it only records
    the four arguments on ``self._pending_reward_args`` so that
    ``_collect_reward_workers`` can compute the rewards synchronously.

    Args:
        inputs: Passed through unchanged to the reward computation.
        prompts: Passed through unchanged to the reward computation.
        completions: Passed through unchanged to the reward computation.
        completion_ids_list: Passed through unchanged to the reward
            computation.

    Returns:
        None.
    """
    staged_call = (inputs, prompts, completions, completion_ids_list)
    self._pending_reward_args = staged_call
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
    """Return the rewards for the batch, computing them synchronously.

    Subclasses may override this to gather results from parallel workers.
    This default implementation prefers the arguments previously recorded
    on ``self._pending_reward_args`` (if any), clearing that attribute
    before delegating to ``_compute_rewards_for_batch``. When nothing is
    pending, the arguments passed to this call are used instead.

    Args:
        inputs: Used only when no pending arguments are recorded.
        prompts: Used only when no pending arguments are recorded.
        completions: Used only when no pending arguments are recorded.
        completion_ids_list: Used only when no pending arguments are
            recorded.

    Returns:
        Whatever ``self._compute_rewards_for_batch`` returns.
    """
    staged_call = getattr(self, "_pending_reward_args", None)
    if staged_call is None:
        # Nothing was launched beforehand: fall back to the caller's args.
        return self._compute_rewards_for_batch(
            inputs, prompts, completions, completion_ids_list
        )

    # Consume the staged call so it cannot be reused by a later collect.
    self._pending_reward_args = None
    return self._compute_rewards_for_batch(*staged_call)
def _post_advantage_hook( def _post_advantage_hook(
self, self,
data: dict, data: dict,
@@ -929,6 +946,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
forward_kwargs[key] = data[key] forward_kwargs[key] = data[key]
num_images = data.get("num_images") num_images = data.get("num_images")
# --- Launch rewards in parallel with logprobs ---
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
# --- Policy logprobs --- # --- Policy logprobs ---
logprob_batch_size = min(batch_size * 4, len(prompt_ids)) logprob_batch_size = min(batch_size * 4, len(prompt_ids))
with disable_gradient_checkpointing( with disable_gradient_checkpointing(
@@ -1013,8 +1033,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0) is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
data["importance_sampling_ratio"] = is_ratio data["importance_sampling_ratio"] = is_ratio
# --- Rewards --- # --- Collect rewards (launched before logprobs, should be done) ---
rewards_per_func = self._compute_rewards_for_batch( rewards_per_func = self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list inputs, prompts, completions, completion_ids_list
) )
@@ -1267,6 +1287,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
): ):
num_images = num_images[s_start:s_end] num_images = num_images[s_start:s_end]
# --- Launch rewards in parallel with logprobs ---
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
# --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
logprob_batch_size = min(batch_size * 4, chunk_size) logprob_batch_size = min(batch_size * 4, chunk_size)
with disable_gradient_checkpointing( with disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs self.model, self.args.gradient_checkpointing_kwargs
@@ -1356,8 +1380,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
) )
data["ref_per_token_logps"][s_start:s_end] = ref_logps data["ref_per_token_logps"][s_start:s_end] = ref_logps
# --- Rewards --- # --- Collect rewards (should already be done, ran in parallel with logprobs) ---
rewards_per_func = self._compute_rewards_for_batch( rewards_per_func = self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list inputs, prompts, completions, completion_ids_list
) )
@@ -1620,6 +1644,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Produce a new rollout # Produce a new rollout
self._maybe_sync_vllm_weights() self._maybe_sync_vllm_weights()
rollout_dataset = self.data_producer.produce( rollout_dataset = self.data_producer.produce(
self.model, self.model,
self.state.global_step, self.state.global_step,
@@ -1728,7 +1753,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
token_type_ids=inputs.get("token_type_ids"), token_type_ids=inputs.get("token_type_ids"),
mm_token_type_ids=inputs.get("mm_token_type_ids"), mm_token_type_ids=inputs.get("mm_token_type_ids"),
) )
if self.top_entropy_quantile < 1.0: if self.top_entropy_quantile < 1.0:
entropy_mask = self.get_high_entropy_mask( entropy_mask = self.get_high_entropy_mask(
entropies, mask, 1 - self.top_entropy_quantile entropies, mask, 1 - self.top_entropy_quantile
@@ -1910,12 +1934,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def training_step(self, model, inputs, num_items_in_batch=None): def training_step(self, model, inputs, num_items_in_batch=None):
t0 = time.perf_counter()
output = super().training_step(model, inputs, num_items_in_batch) output = super().training_step(model, inputs, num_items_in_batch)
self._step += 1 self._step += 1
t1 = time.perf_counter()
self._current_train_step_time += t1 - t0
if self._step % self.current_gradient_accumulation_steps == 0:
self._metrics["train"]["step_time"].append(self._current_train_step_time)
self._current_train_step_time = 0.0
return output return output

View File

@@ -337,9 +337,16 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
def _compute_rewards_for_batch( def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list self, inputs, prompts, completions, completion_ids_list
): ):
"""Dispatch rewards to parallel subprocess workers when possible.""" """Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
from accelerate.utils import gather self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(inputs, prompts, completions, completion_ids_list)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all( reward_can_bg = all(
not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf) not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs for rf in self.reward_funcs
@@ -347,12 +354,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
num_workers = getattr(self.args, "reward_num_workers", 1) num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1: if not reward_can_bg or num_workers <= 1:
return self._calculate_rewards( # Can't parallelize — store args for sync fallback in collect
inputs, prompts, completions, completion_ids_list self._reward_workers_used = None
) self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
return
workers = self._get_reward_workers() workers = self._get_reward_workers()
device = self.accelerator.device
num_generations = self.num_generations num_generations = self.num_generations
num_prompts = len(prompts) num_prompts = len(prompts)
num_groups = num_prompts // num_generations num_groups = num_prompts // num_generations
@@ -379,7 +386,28 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
) )
workers_used.append(conn) workers_used.append(conn)
# Collect results self._reward_workers_used = workers_used
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Collect reward results from subprocess workers (blocks until done)."""
from accelerate.utils import gather
workers_used = getattr(self, "_reward_workers_used", None)
args = getattr(self, "_pending_reward_args", None)
self._reward_workers_used = None
self._pending_reward_args = None
if workers_used is None:
# Sync fallback — compute on main thread
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
device = self.accelerator.device
num_prompts = len(args[1]) if args else len(prompts)
# Collect results from workers
all_worker_results = [] all_worker_results = []
any_failed = False any_failed = False
for conn in workers_used: for conn in workers_used:
@@ -404,9 +432,9 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
return gather(rewards_per_func) return gather(rewards_per_func)
# Fallback to main thread on failure # Fallback to main thread on failure
return self._calculate_rewards( if args is not None:
inputs, prompts, completions, completion_ids_list return self._calculate_rewards(*args)
) return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
def _post_advantage_hook( def _post_advantage_hook(
self, self,
@@ -451,9 +479,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
self._replay_buffer.add(replay_scores[gi].item(), group_data) self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries # Replace zero-signal groups with high-signal replay buffer entries
if True: # Only in non-streaming path (s_start is None) — streaming scores
# groups incrementally, so replacement + logprob recompute would be
# too expensive per chunk.
n_replaced = 0
if s_start is None:
no_signal = ~has_signal no_signal = ~has_signal
n_replaced = 0
replaced_ranges = [] replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0: if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]: for group_idx in no_signal.nonzero(as_tuple=True)[0]: