fix replay buffer

This commit is contained in:
Wing Lian
2026-03-10 02:18:43 +00:00
parent 9394d17f28
commit bba1330e9b
2 changed files with 42 additions and 15 deletions

View File

@@ -567,9 +567,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Data producer (the proper architecture for async generation)
self.data_producer = None
if getattr(self.args, "use_data_producer", False):
self.data_producer = self._create_data_producer(
kwargs["args"], kwargs["train_dataset"]
)
self.data_producer = self._create_data_producer(kwargs["args"], kwargs["train_dataset"])
if self.args.async_prefetch and self.data_producer is None:
# Legacy path: direct _prepare_inputs override without data producer
@@ -590,17 +588,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
config=producer_config,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=getattr(
args,
"generation_batch_size",
self._train_batch_size * args.gradient_accumulation_steps,
),
generation_batch_size=args.generation_batch_size,
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=getattr(self, "shuffle_dataset", True),
seed=args.seed,
)
# Inject trainer reference (needs accelerator from super().__init__)
data_producer.set_trainer(self)
if args.async_prefetch:
@@ -868,6 +861,28 @@ class AsyncGRPOTrainer(GRPOTrainer):
output[k] = forward_kwargs[k]
return output
# ------------------------------------------------------------------
# Hooks (overridden by subclasses like FastAsyncGRPOTrainer)
# ------------------------------------------------------------------
def _compute_rewards_for_batch(self, inputs, prompts, completions, completion_ids_list):
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
def _post_advantage_hook(
self,
data: dict,
rewards_per_func,
advantages,
inputs: list,
num_generations: int,
mode: str,
s_start: int | None = None,
s_end: int | None = None,
is_last_chunk: bool = True,
) -> None:
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
# ------------------------------------------------------------------
# Main-thread scoring
# ------------------------------------------------------------------
@@ -999,7 +1014,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
data["importance_sampling_ratio"] = is_ratio
# --- Rewards ---
rewards_per_func = self._calculate_rewards(
rewards_per_func = self._compute_rewards_for_batch(
inputs, prompts, completions, completion_ids_list
)
@@ -1069,6 +1084,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
advantages = advantages[process_slice]
data["advantages"] = advantages
# --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
self._post_advantage_hook(
data, rewards_per_func, advantages, inputs, num_generations, mode,
)
# --- Metrics ---
for i, name in enumerate(self.reward_func_names):
self._metrics[mode][f"rewards/{name}/mean"].append(
@@ -1337,7 +1357,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
data["ref_per_token_logps"][s_start:s_end] = ref_logps
# --- Rewards ---
rewards_per_func = self._calculate_rewards(
rewards_per_func = self._compute_rewards_for_batch(
inputs, prompts, completions, completion_ids_list
)
@@ -1405,6 +1425,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device)
data["advantages"][s_start:s_end] = advantages
# --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
self._post_advantage_hook(
data, rewards_per_func, advantages, inputs, num_generations, mode,
s_start=s_start, s_end=s_end, is_last_chunk=is_last_chunk,
)
# --- Chunk metrics ---
for i, name in enumerate(self.reward_func_names):
self._metrics[mode][f"rewards/{name}/mean"].append(

View File

@@ -430,11 +430,11 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
)
per_group_std = local_grouped.std(dim=1)
has_signal = (per_group_std > 0).any(dim=1)
offset = s_start or 0
if has_signal.any():
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
offset = s_start or 0
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
gi = group_idx.item()
start = offset + gi * num_generations
@@ -450,8 +450,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
group_data[key] = val[start:end].clone()
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups (only in deferred path, not streaming)
if s_start is None:
# Replace zero-signal groups with high-signal replay buffer entries
if True:
no_signal = ~has_signal
n_replaced = 0
replaced_ranges = []
@@ -462,7 +462,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
break
sampled_group = sampled[0]
gi = group_idx.item()
start, end = gi * num_generations, (gi + 1) * num_generations
start = offset + gi * num_generations
end = start + num_generations
for key, val in sampled_group.items():
if key in data and isinstance(data[key], torch.Tensor):
src = val.to(data[key].device)