diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py
index b12f40a4e..be1e987cc 100644
--- a/src/axolotl/core/trainers/grpo/async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/async_trainer.py
@@ -25,7 +25,7 @@ import concurrent.futures
 import logging
 import queue
 import threading
-import time
+
 from abc import ABC, abstractmethod
 from collections import deque
 from contextlib import nullcontext
@@ -562,7 +562,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
         self._prompt_iter = None
         self._last_synced_step = -1
         self._buffered_inputs: list | None = None  # override stock attr
-        self._current_train_step_time = 0.0

         # Data producer (the proper architecture for async generation)
         self.data_producer = None
@@ -869,6 +868,24 @@
         """Compute rewards for a batch. Override for parallel workers, caching, etc."""
         return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)

+    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Launch reward computation in background. Override for parallel dispatch.
+
+        Default: no-op (rewards computed synchronously in _collect_reward_workers).
+        """
+        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+
+    def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Collect reward results. Override to collect from parallel workers.
+
+        Default: compute rewards synchronously now.
+        """
+        args = getattr(self, "_pending_reward_args", None)
+        if args is not None:
+            self._pending_reward_args = None
+            return self._compute_rewards_for_batch(*args)
+        return self._compute_rewards_for_batch(inputs, prompts, completions, completion_ids_list)
+
     def _post_advantage_hook(
         self,
         data: dict,
@@ -929,6 +946,9 @@
                 forward_kwargs[key] = data[key]
         num_images = data.get("num_images")

+        # --- Launch rewards in parallel with logprobs ---
+        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
+
         # --- Policy logprobs ---
         logprob_batch_size = min(batch_size * 4, len(prompt_ids))
         with disable_gradient_checkpointing(
@@ -1013,8 +1033,8 @@
             is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
             data["importance_sampling_ratio"] = is_ratio

-        # --- Rewards ---
-        rewards_per_func = self._compute_rewards_for_batch(
+        # --- Collect rewards (launched before logprobs, should be done) ---
+        rewards_per_func = self._collect_reward_workers(
             inputs, prompts, completions, completion_ids_list
         )

@@ -1267,6 +1287,10 @@
             ):
                 num_images = num_images[s_start:s_end]

+            # --- Launch rewards in parallel with logprobs ---
+            self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
+
+            # --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
             logprob_batch_size = min(batch_size * 4, chunk_size)
             with disable_gradient_checkpointing(
                 self.model, self.args.gradient_checkpointing_kwargs
@@ -1356,8 +1380,8 @@
                 )
                 data["ref_per_token_logps"][s_start:s_end] = ref_logps

-            # --- Rewards ---
-            rewards_per_func = self._compute_rewards_for_batch(
+            # --- Collect rewards (should already be done, ran in parallel with logprobs) ---
+            rewards_per_func = self._collect_reward_workers(
                 inputs, prompts, completions, completion_ids_list
             )

@@ -1620,6 +1644,7 @@

             # Produce a new rollout
             self._maybe_sync_vllm_weights()
+
             rollout_dataset = self.data_producer.produce(
                 self.model,
                 self.state.global_step,
@@ -1728,7 +1753,6 @@
             token_type_ids=inputs.get("token_type_ids"),
             mm_token_type_ids=inputs.get("mm_token_type_ids"),
         )
-
         if self.top_entropy_quantile < 1.0:
             entropy_mask = self.get_high_entropy_mask(
                 entropies, mask, 1 - self.top_entropy_quantile
@@ -1910,12 +1934,6 @@

     # ------------------------------------------------------------------
     def training_step(self, model, inputs, num_items_in_batch=None):
-        t0 = time.perf_counter()
         output = super().training_step(model, inputs, num_items_in_batch)
         self._step += 1
-        t1 = time.perf_counter()
-        self._current_train_step_time += t1 - t0
-        if self._step % self.current_gradient_accumulation_steps == 0:
-            self._metrics["train"]["step_time"].append(self._current_train_step_time)
-            self._current_train_step_time = 0.0
         return output
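The hook pair added above defines the overlap contract: _launch_reward_workers may return immediately after scheduling work, _collect_reward_workers must block until rewards exist, and the GPU logprob passes sit between the two calls. A minimal sketch of what a subclass override could look like, assuming only that contract (the ThreadedRewardTrainer, _reward_pool, and _rewards_future names are illustrative, not part of this patch):

    from concurrent.futures import ThreadPoolExecutor

    class ThreadedRewardTrainer(AsyncGRPOTrainer):
        # One background thread is enough: launch/collect are strictly paired.
        _reward_pool = ThreadPoolExecutor(max_workers=1)

        def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
            # Non-blocking: schedule reward computation on a background thread.
            self._rewards_future = self._reward_pool.submit(
                self._compute_rewards_for_batch,
                inputs, prompts, completions, completion_ids_list,
            )

        def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list):
            # Blocking: by now the GPU logprob passes have run in parallel.
            return self._rewards_future.result()

FastAsyncGRPOTrainer below takes the subprocess route instead of threads, which sidesteps GIL contention when the reward functions are CPU-bound Python.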
diff --git a/src/axolotl/core/trainers/grpo/fast_async_trainer.py b/src/axolotl/core/trainers/grpo/fast_async_trainer.py
index 467e80858..a19e33c80 100644
--- a/src/axolotl/core/trainers/grpo/fast_async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/fast_async_trainer.py
@@ -337,9 +337,16 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
     def _compute_rewards_for_batch(
         self, inputs, prompts, completions, completion_ids_list
     ):
-        """Dispatch rewards to parallel subprocess workers when possible."""
-        from accelerate.utils import gather
+        """Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
+        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
+        return self._collect_reward_workers(inputs, prompts, completions, completion_ids_list)
+    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+        """Send reward work to subprocess workers (non-blocking).
+
+        Results are collected later by _collect_reward_workers, allowing GPU
+        logprob computation to overlap with CPU reward computation.
+        """
         reward_can_bg = all(
             not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf)
             for rf in self.reward_funcs
         )
@@ -347,12 +354,12 @@
         num_workers = getattr(self.args, "reward_num_workers", 1)

         if not reward_can_bg or num_workers <= 1:
-            return self._calculate_rewards(
-                inputs, prompts, completions, completion_ids_list
-            )
+            # Can't parallelize — store args for sync fallback in collect
+            self._reward_workers_used = None
+            self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+            return

         workers = self._get_reward_workers()
-        device = self.accelerator.device
         num_generations = self.num_generations
         num_prompts = len(prompts)
         num_groups = num_prompts // num_generations
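_launch_reward_workers slices the prompt groups across subprocess pipe connections and returns without waiting; collection happens after the logprob passes. A hedged sketch of what the worker side of each connection might look like (the real loop lives in _get_reward_workers, which this patch does not touch; the message shape and the reward_worker_loop name are assumptions):

    def reward_worker_loop(conn, reward_funcs):
        # Hypothetical worker-side loop; the actual protocol is defined by
        # _get_reward_workers and is not shown in this patch.
        while True:
            msg = conn.recv()
            if msg is None:  # shutdown sentinel
                break
            prompts, completions = msg
            try:
                results = [rf(prompts=prompts, completions=completions) for rf in reward_funcs]
                conn.send(("ok", results))
            except Exception as exc:
                # The parent marks any_failed and recomputes on the main thread.
                conn.send(("error", repr(exc)))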
+ """ reward_can_bg = all( not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf) for rf in self.reward_funcs @@ -347,12 +354,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): num_workers = getattr(self.args, "reward_num_workers", 1) if not reward_can_bg or num_workers <= 1: - return self._calculate_rewards( - inputs, prompts, completions, completion_ids_list - ) + # Can't parallelize — store args for sync fallback in collect + self._reward_workers_used = None + self._pending_reward_args = (inputs, prompts, completions, completion_ids_list) + return workers = self._get_reward_workers() - device = self.accelerator.device num_generations = self.num_generations num_prompts = len(prompts) num_groups = num_prompts // num_generations @@ -379,7 +386,28 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): ) workers_used.append(conn) - # Collect results + self._reward_workers_used = workers_used + self._pending_reward_args = (inputs, prompts, completions, completion_ids_list) + + def _collect_reward_workers(self, inputs, prompts, completions, completion_ids_list): + """Collect reward results from subprocess workers (blocks until done).""" + from accelerate.utils import gather + + workers_used = getattr(self, "_reward_workers_used", None) + args = getattr(self, "_pending_reward_args", None) + self._reward_workers_used = None + self._pending_reward_args = None + + if workers_used is None: + # Sync fallback — compute on main thread + if args is not None: + return self._calculate_rewards(*args) + return self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + device = self.accelerator.device + num_prompts = len(args[1]) if args else len(prompts) + + # Collect results from workers all_worker_results = [] any_failed = False for conn in workers_used: @@ -404,9 +432,9 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): return gather(rewards_per_func) # Fallback to main thread on failure - return self._calculate_rewards( - inputs, prompts, completions, completion_ids_list - ) + if args is not None: + return self._calculate_rewards(*args) + return self._calculate_rewards(inputs, prompts, completions, completion_ids_list) def _post_advantage_hook( self, @@ -451,9 +479,12 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): self._replay_buffer.add(replay_scores[gi].item(), group_data) # Replace zero-signal groups with high-signal replay buffer entries - if True: + # Only in non-streaming path (s_start is None) — streaming scores + # groups incrementally, so replacement + logprob recompute would be + # too expensive per chunk. + n_replaced = 0 + if s_start is None: no_signal = ~has_signal - n_replaced = 0 replaced_ranges = [] if no_signal.any() and len(self._replay_buffer) > 0: for group_idx in no_signal.nonzero(as_tuple=True)[0]: