From bba1330e9ba70ce36f38dd57e343e6bf546a1960 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 10 Mar 2026 02:18:43 +0000 Subject: [PATCH] fix replay buffer --- .../core/trainers/grpo/async_trainer.py | 48 ++++++++++++++----- .../core/trainers/grpo/fast_async_trainer.py | 9 ++-- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index 3af8c4a28..b12f40a4e 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -567,9 +567,7 @@ class AsyncGRPOTrainer(GRPOTrainer): # Data producer (the proper architecture for async generation) self.data_producer = None if getattr(self.args, "use_data_producer", False): - self.data_producer = self._create_data_producer( - kwargs["args"], kwargs["train_dataset"] - ) + self.data_producer = self._create_data_producer(kwargs["args"], kwargs["train_dataset"]) if self.args.async_prefetch and self.data_producer is None: # Legacy path: direct _prepare_inputs override without data producer @@ -590,17 +588,12 @@ class AsyncGRPOTrainer(GRPOTrainer): config=producer_config, prompt_dataset=train_dataset, num_generations=self.num_generations, - generation_batch_size=getattr( - args, - "generation_batch_size", - self._train_batch_size * args.gradient_accumulation_steps, - ), + generation_batch_size=args.generation_batch_size, train_batch_size=args.per_device_train_batch_size, steps_per_generation=args.steps_per_generation, shuffle_dataset=getattr(self, "shuffle_dataset", True), seed=args.seed, ) - # Inject trainer reference (needs accelerator from super().__init__) data_producer.set_trainer(self) if args.async_prefetch: @@ -868,6 +861,28 @@ class AsyncGRPOTrainer(GRPOTrainer): output[k] = forward_kwargs[k] return output + # ------------------------------------------------------------------ + # Hooks (overridden by subclasses like FastAsyncGRPOTrainer) + # ------------------------------------------------------------------ + + def _compute_rewards_for_batch(self, inputs, prompts, completions, completion_ids_list): + """Compute rewards for a batch. Override for parallel workers, caching, etc.""" + return self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + def _post_advantage_hook( + self, + data: dict, + rewards_per_func, + advantages, + inputs: list, + num_generations: int, + mode: str, + s_start: int | None = None, + s_end: int | None = None, + is_last_chunk: bool = True, + ) -> None: + """Called after advantages are computed. Override for replay buffer, re-roll, etc.""" + # ------------------------------------------------------------------ # Main-thread scoring # ------------------------------------------------------------------ @@ -999,7 +1014,7 @@ class AsyncGRPOTrainer(GRPOTrainer): data["importance_sampling_ratio"] = is_ratio # --- Rewards --- - rewards_per_func = self._calculate_rewards( + rewards_per_func = self._compute_rewards_for_batch( inputs, prompts, completions, completion_ids_list ) @@ -1069,6 +1084,11 @@ class AsyncGRPOTrainer(GRPOTrainer): advantages = advantages[process_slice] data["advantages"] = advantages + # --- Post-advantage hook (for replay buffer, re-roll, etc.) --- + self._post_advantage_hook( + data, rewards_per_func, advantages, inputs, num_generations, mode, + ) + # --- Metrics --- for i, name in enumerate(self.reward_func_names): self._metrics[mode][f"rewards/{name}/mean"].append( @@ -1337,7 +1357,7 @@ class AsyncGRPOTrainer(GRPOTrainer): data["ref_per_token_logps"][s_start:s_end] = ref_logps # --- Rewards --- - rewards_per_func = self._calculate_rewards( + rewards_per_func = self._compute_rewards_for_batch( inputs, prompts, completions, completion_ids_list ) @@ -1405,6 +1425,12 @@ class AsyncGRPOTrainer(GRPOTrainer): data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device) data["advantages"][s_start:s_end] = advantages + # --- Post-advantage hook (for replay buffer, re-roll, etc.) --- + self._post_advantage_hook( + data, rewards_per_func, advantages, inputs, num_generations, mode, + s_start=s_start, s_end=s_end, is_last_chunk=is_last_chunk, + ) + # --- Chunk metrics --- for i, name in enumerate(self.reward_func_names): self._metrics[mode][f"rewards/{name}/mean"].append( diff --git a/src/axolotl/core/trainers/grpo/fast_async_trainer.py b/src/axolotl/core/trainers/grpo/fast_async_trainer.py index 3367d05e8..467e80858 100644 --- a/src/axolotl/core/trainers/grpo/fast_async_trainer.py +++ b/src/axolotl/core/trainers/grpo/fast_async_trainer.py @@ -430,11 +430,11 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): ) per_group_std = local_grouped.std(dim=1) has_signal = (per_group_std > 0).any(dim=1) + offset = s_start or 0 if has_signal.any(): grouped_adv = advantages.view(-1, num_generations) replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1) - offset = s_start or 0 for group_idx in has_signal.nonzero(as_tuple=True)[0]: gi = group_idx.item() start = offset + gi * num_generations @@ -450,8 +450,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): group_data[key] = val[start:end].clone() self._replay_buffer.add(replay_scores[gi].item(), group_data) - # Replace zero-signal groups (only in deferred path, not streaming) - if s_start is None: + # Replace zero-signal groups with high-signal replay buffer entries + if True: no_signal = ~has_signal n_replaced = 0 replaced_ranges = [] @@ -462,7 +462,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer): break sampled_group = sampled[0] gi = group_idx.item() - start, end = gi * num_generations, (gi + 1) * num_generations + start = offset + gi * num_generations + end = start + num_generations for key, val in sampled_group.items(): if key in data and isinstance(data[key], torch.Tensor): src = val.to(data[key].device)