make gpus go brrr
This commit is contained in:
@@ -25,7 +25,7 @@ import concurrent.futures
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
@@ -562,7 +562,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
self._prompt_iter = None
|
self._prompt_iter = None
|
||||||
self._last_synced_step = -1
|
self._last_synced_step = -1
|
||||||
self._buffered_inputs: list | None = None # override stock attr
|
self._buffered_inputs: list | None = None # override stock attr
|
||||||
self._current_train_step_time = 0.0
|
|
||||||
|
|
||||||
# Data producer (the proper architecture for async generation)
|
# Data producer (the proper architecture for async generation)
|
||||||
self.data_producer = None
|
self.data_producer = None
|
||||||
@@ -869,6 +868,24 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
|
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
|
||||||
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||||
|
"""Launch reward computation in background. Override for parallel dispatch.
|
||||||
|
|
||||||
|
Default: no-op (rewards computed synchronously in _collect_reward_workers).
|
||||||
|
"""
|
||||||
|
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||||
|
"""Collect reward results. Override to collect from parallel workers.
|
||||||
|
|
||||||
|
Default: compute rewards synchronously now.
|
||||||
|
"""
|
||||||
|
args = getattr(self, "_pending_reward_args", None)
|
||||||
|
if args is not None:
|
||||||
|
self._pending_reward_args = None
|
||||||
|
return self._compute_rewards_for_batch(*args)
|
||||||
|
return self._compute_rewards_for_batch(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
def _post_advantage_hook(
|
def _post_advantage_hook(
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
@@ -929,6 +946,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
forward_kwargs[key] = data[key]
|
forward_kwargs[key] = data[key]
|
||||||
num_images = data.get("num_images")
|
num_images = data.get("num_images")
|
||||||
|
|
||||||
|
# --- Launch rewards in parallel with logprobs ---
|
||||||
|
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
# --- Policy logprobs ---
|
# --- Policy logprobs ---
|
||||||
logprob_batch_size = min(batch_size * 4, len(prompt_ids))
|
logprob_batch_size = min(batch_size * 4, len(prompt_ids))
|
||||||
with disable_gradient_checkpointing(
|
with disable_gradient_checkpointing(
|
||||||
@@ -1013,8 +1033,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
||||||
data["importance_sampling_ratio"] = is_ratio
|
data["importance_sampling_ratio"] = is_ratio
|
||||||
|
|
||||||
# --- Rewards ---
|
# --- Collect rewards (launched before logprobs, should be done) ---
|
||||||
rewards_per_func = self._compute_rewards_for_batch(
|
rewards_per_func = self._collect_reward_workers(
|
||||||
inputs, prompts, completions, completion_ids_list
|
inputs, prompts, completions, completion_ids_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1267,6 +1287,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
):
|
):
|
||||||
num_images = num_images[s_start:s_end]
|
num_images = num_images[s_start:s_end]
|
||||||
|
|
||||||
|
# --- Launch rewards in parallel with logprobs ---
|
||||||
|
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
# --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
|
||||||
logprob_batch_size = min(batch_size * 4, chunk_size)
|
logprob_batch_size = min(batch_size * 4, chunk_size)
|
||||||
with disable_gradient_checkpointing(
|
with disable_gradient_checkpointing(
|
||||||
self.model, self.args.gradient_checkpointing_kwargs
|
self.model, self.args.gradient_checkpointing_kwargs
|
||||||
@@ -1356,8 +1380,8 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
)
|
)
|
||||||
data["ref_per_token_logps"][s_start:s_end] = ref_logps
|
data["ref_per_token_logps"][s_start:s_end] = ref_logps
|
||||||
|
|
||||||
# --- Rewards ---
|
# --- Collect rewards (should already be done, ran in parallel with logprobs) ---
|
||||||
rewards_per_func = self._compute_rewards_for_batch(
|
rewards_per_func = self._collect_reward_workers(
|
||||||
inputs, prompts, completions, completion_ids_list
|
inputs, prompts, completions, completion_ids_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1620,6 +1644,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
|
|
||||||
# Produce a new rollout
|
# Produce a new rollout
|
||||||
self._maybe_sync_vllm_weights()
|
self._maybe_sync_vllm_weights()
|
||||||
|
|
||||||
rollout_dataset = self.data_producer.produce(
|
rollout_dataset = self.data_producer.produce(
|
||||||
self.model,
|
self.model,
|
||||||
self.state.global_step,
|
self.state.global_step,
|
||||||
@@ -1728,7 +1753,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
token_type_ids=inputs.get("token_type_ids"),
|
token_type_ids=inputs.get("token_type_ids"),
|
||||||
mm_token_type_ids=inputs.get("mm_token_type_ids"),
|
mm_token_type_ids=inputs.get("mm_token_type_ids"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.top_entropy_quantile < 1.0:
|
if self.top_entropy_quantile < 1.0:
|
||||||
entropy_mask = self.get_high_entropy_mask(
|
entropy_mask = self.get_high_entropy_mask(
|
||||||
entropies, mask, 1 - self.top_entropy_quantile
|
entropies, mask, 1 - self.top_entropy_quantile
|
||||||
@@ -1910,12 +1934,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def training_step(self, model, inputs, num_items_in_batch=None):
|
def training_step(self, model, inputs, num_items_in_batch=None):
|
||||||
t0 = time.perf_counter()
|
|
||||||
output = super().training_step(model, inputs, num_items_in_batch)
|
output = super().training_step(model, inputs, num_items_in_batch)
|
||||||
self._step += 1
|
self._step += 1
|
||||||
t1 = time.perf_counter()
|
|
||||||
self._current_train_step_time += t1 - t0
|
|
||||||
if self._step % self.current_gradient_accumulation_steps == 0:
|
|
||||||
self._metrics["train"]["step_time"].append(self._current_train_step_time)
|
|
||||||
self._current_train_step_time = 0.0
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -337,9 +337,16 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
def _compute_rewards_for_batch(
|
def _compute_rewards_for_batch(
|
||||||
self, inputs, prompts, completions, completion_ids_list
|
self, inputs, prompts, completions, completion_ids_list
|
||||||
):
|
):
|
||||||
"""Dispatch rewards to parallel subprocess workers when possible."""
|
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
|
||||||
from accelerate.utils import gather
|
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||||
|
return self._collect_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||||
|
"""Send reward work to subprocess workers (non-blocking).
|
||||||
|
|
||||||
|
Results are collected later by _collect_reward_workers, allowing GPU
|
||||||
|
logprob computation to overlap with CPU reward computation.
|
||||||
|
"""
|
||||||
reward_can_bg = all(
|
reward_can_bg = all(
|
||||||
not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
|
not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
|
||||||
for rf in self.reward_funcs
|
for rf in self.reward_funcs
|
||||||
@@ -347,12 +354,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
num_workers = getattr(self.args, "reward_num_workers", 1)
|
num_workers = getattr(self.args, "reward_num_workers", 1)
|
||||||
|
|
||||||
if not reward_can_bg or num_workers <= 1:
|
if not reward_can_bg or num_workers <= 1:
|
||||||
return self._calculate_rewards(
|
# Can't parallelize — store args for sync fallback in collect
|
||||||
inputs, prompts, completions, completion_ids_list
|
self._reward_workers_used = None
|
||||||
)
|
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||||
|
return
|
||||||
|
|
||||||
workers = self._get_reward_workers()
|
workers = self._get_reward_workers()
|
||||||
device = self.accelerator.device
|
|
||||||
num_generations = self.num_generations
|
num_generations = self.num_generations
|
||||||
num_prompts = len(prompts)
|
num_prompts = len(prompts)
|
||||||
num_groups = num_prompts // num_generations
|
num_groups = num_prompts // num_generations
|
||||||
@@ -379,7 +386,28 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
)
|
)
|
||||||
workers_used.append(conn)
|
workers_used.append(conn)
|
||||||
|
|
||||||
# Collect results
|
self._reward_workers_used = workers_used
|
||||||
|
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||||
|
"""Collect reward results from subprocess workers (blocks until done)."""
|
||||||
|
from accelerate.utils import gather
|
||||||
|
|
||||||
|
workers_used = getattr(self, "_reward_workers_used", None)
|
||||||
|
args = getattr(self, "_pending_reward_args", None)
|
||||||
|
self._reward_workers_used = None
|
||||||
|
self._pending_reward_args = None
|
||||||
|
|
||||||
|
if workers_used is None:
|
||||||
|
# Sync fallback — compute on main thread
|
||||||
|
if args is not None:
|
||||||
|
return self._calculate_rewards(*args)
|
||||||
|
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
device = self.accelerator.device
|
||||||
|
num_prompts = len(args[1]) if args else len(prompts)
|
||||||
|
|
||||||
|
# Collect results from workers
|
||||||
all_worker_results = []
|
all_worker_results = []
|
||||||
any_failed = False
|
any_failed = False
|
||||||
for conn in workers_used:
|
for conn in workers_used:
|
||||||
@@ -404,9 +432,9 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
return gather(rewards_per_func)
|
return gather(rewards_per_func)
|
||||||
|
|
||||||
# Fallback to main thread on failure
|
# Fallback to main thread on failure
|
||||||
return self._calculate_rewards(
|
if args is not None:
|
||||||
inputs, prompts, completions, completion_ids_list
|
return self._calculate_rewards(*args)
|
||||||
)
|
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
def _post_advantage_hook(
|
def _post_advantage_hook(
|
||||||
self,
|
self,
|
||||||
@@ -451,9 +479,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
||||||
|
|
||||||
# Replace zero-signal groups with high-signal replay buffer entries
|
# Replace zero-signal groups with high-signal replay buffer entries
|
||||||
if True:
|
# Only in non-streaming path (s_start is None) — streaming scores
|
||||||
|
# groups incrementally, so replacement + logprob recompute would be
|
||||||
|
# too expensive per chunk.
|
||||||
|
n_replaced = 0
|
||||||
|
if s_start is None:
|
||||||
no_signal = ~has_signal
|
no_signal = ~has_signal
|
||||||
n_replaced = 0
|
|
||||||
replaced_ranges = []
|
replaced_ranges = []
|
||||||
if no_signal.any() and len(self._replay_buffer) > 0:
|
if no_signal.any() and len(self._replay_buffer) > 0:
|
||||||
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
|
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
|
||||||
|
|||||||
Reference in New Issue
Block a user