fix replay buffer
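Summary, as read from the hunks below: both scoring paths in `AsyncGRPOTrainer` now route rewards through an overridable `_compute_rewards_for_batch` hook and fire a `_post_advantage_hook` once advantages are in place, and `FastAsyncGRPOTrainer`'s replay-buffer logic is fixed to work on the streaming (chunked) path by hoisting the chunk `offset` out of the signal branch and applying it when writing replacement groups back into `data`.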
```diff
@@ -567,9 +567,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
         # Data producer (the proper architecture for async generation)
         self.data_producer = None
         if getattr(self.args, "use_data_producer", False):
-            self.data_producer = self._create_data_producer(
-                kwargs["args"], kwargs["train_dataset"]
-            )
+            self.data_producer = self._create_data_producer(kwargs["args"], kwargs["train_dataset"])

         if self.args.async_prefetch and self.data_producer is None:
             # Legacy path: direct _prepare_inputs override without data producer
```
```diff
@@ -590,17 +588,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
                 config=producer_config,
                 prompt_dataset=train_dataset,
                 num_generations=self.num_generations,
-                generation_batch_size=getattr(
-                    args,
-                    "generation_batch_size",
-                    self._train_batch_size * args.gradient_accumulation_steps,
-                ),
+                generation_batch_size=args.generation_batch_size,
                 train_batch_size=args.per_device_train_batch_size,
                 steps_per_generation=args.steps_per_generation,
                 shuffle_dataset=getattr(self, "shuffle_dataset", True),
                 seed=args.seed,
             )
             # Inject trainer reference (needs accelerator from super().__init__)
             data_producer.set_trainer(self)

         if args.async_prefetch:
```
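Together with the collapsed call in the first hunk, this drops a `getattr` fallback: `generation_batch_size` is now read directly off `args`, so the config is expected to define it rather than falling back to `self._train_batch_size * args.gradient_accumulation_steps`.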
```diff
@@ -868,6 +861,28 @@ class AsyncGRPOTrainer(GRPOTrainer):
                 output[k] = forward_kwargs[k]
         return output

+    # ------------------------------------------------------------------
+    # Hooks (overridden by subclasses like FastAsyncGRPOTrainer)
+    # ------------------------------------------------------------------
+
+    def _compute_rewards_for_batch(self, inputs, prompts, completions, completion_ids_list):
+        """Compute rewards for a batch. Override for parallel workers, caching, etc."""
+        return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
+
+    def _post_advantage_hook(
+        self,
+        data: dict,
+        rewards_per_func,
+        advantages,
+        inputs: list,
+        num_generations: int,
+        mode: str,
+        s_start: int | None = None,
+        s_end: int | None = None,
+        is_last_chunk: bool = True,
+    ) -> None:
+        """Called after advantages are computed. Override for replay buffer, re-roll, etc."""
+
     # ------------------------------------------------------------------
     # Main-thread scoring
     # ------------------------------------------------------------------
```
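The default `_compute_rewards_for_batch` simply delegates to the inherited `_calculate_rewards`, so base behavior is unchanged; the point is the seam. A minimal sketch of an override, assuming only the signature above (the subclass name and the `rewards/scoring_seconds` metric key are made up for illustration):

```python
import time


class TimedRewardTrainer(AsyncGRPOTrainer):
    """Hypothetical subclass: measures wall-clock time spent scoring each batch."""

    def _compute_rewards_for_batch(self, inputs, prompts, completions, completion_ids_list):
        t0 = time.perf_counter()
        rewards_per_func = super()._compute_rewards_for_batch(
            inputs, prompts, completions, completion_ids_list
        )
        # _metrics[mode] is indexed the same way the hunks below index it;
        # "train" is assumed here for brevity.
        self._metrics["train"]["rewards/scoring_seconds"].append(time.perf_counter() - t0)
        return rewards_per_func
```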
```diff
@@ -999,7 +1014,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
         data["importance_sampling_ratio"] = is_ratio

         # --- Rewards ---
-        rewards_per_func = self._calculate_rewards(
+        rewards_per_func = self._compute_rewards_for_batch(
             inputs, prompts, completions, completion_ids_list
         )

```
```diff
@@ -1069,6 +1084,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
             advantages = advantages[process_slice]
         data["advantages"] = advantages

+        # --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
+        self._post_advantage_hook(
+            data, rewards_per_func, advantages, inputs, num_generations, mode,
+        )
+
         # --- Metrics ---
         for i, name in enumerate(self.reward_func_names):
             self._metrics[mode][f"rewards/{name}/mean"].append(
```
```diff
@@ -1337,7 +1357,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
             data["ref_per_token_logps"][s_start:s_end] = ref_logps

         # --- Rewards ---
-        rewards_per_func = self._calculate_rewards(
+        rewards_per_func = self._compute_rewards_for_batch(
             inputs, prompts, completions, completion_ids_list
         )

```
```diff
@@ -1405,6 +1425,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
             data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device)
         data["advantages"][s_start:s_end] = advantages

+        # --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
+        self._post_advantage_hook(
+            data, rewards_per_func, advantages, inputs, num_generations, mode,
+            s_start=s_start, s_end=s_end, is_last_chunk=is_last_chunk,
+        )
+
         # --- Chunk metrics ---
         for i, name in enumerate(self.reward_func_names):
             self._metrics[mode][f"rewards/{name}/mean"].append(
```
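Both call sites now funnel through `_post_advantage_hook`: the deferred path (previous hunk) passes only the batch-level arguments, while this chunked path also passes `s_start`/`s_end` and `is_last_chunk`. A sketch of an override that works under both call shapes, assuming only the signature introduced above (the subclass and the accumulated statistic are illustrative):

```python
class ChunkStatsTrainer(AsyncGRPOTrainer):
    """Hypothetical subclass: accumulates an advantage statistic across streamed chunks."""

    def _post_advantage_hook(self, data, rewards_per_func, advantages, inputs,
                             num_generations, mode, s_start=None, s_end=None,
                             is_last_chunk=True):
        # `advantages` covers only the rows just scored: the whole batch on the
        # deferred path (s_start is None), or rows [s_start:s_end) when streaming.
        self._abs_adv_sum = getattr(self, "_abs_adv_sum", 0.0) + advantages.abs().sum().item()
        if is_last_chunk:
            # Flush once per generation batch, whichever path produced it.
            self._metrics[mode]["advantages/abs_sum"].append(self._abs_adv_sum)
            self._abs_adv_sum = 0.0
```

The remaining hunks are in `FastAsyncGRPOTrainer`, the subclass that actually uses these hooks for its replay buffer.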
```diff
@@ -430,11 +430,11 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
         )
         per_group_std = local_grouped.std(dim=1)
         has_signal = (per_group_std > 0).any(dim=1)
+        offset = s_start or 0

         if has_signal.any():
             grouped_adv = advantages.view(-1, num_generations)
             replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
-            offset = s_start or 0
             for group_idx in has_signal.nonzero(as_tuple=True)[0]:
                 gi = group_idx.item()
                 start = offset + gi * num_generations
```
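Hoisting `offset = s_start or 0` above the `if has_signal.any():` guard matters because the replacement logic in the next hunks needs the offset even when no group has signal; previously it was undefined in that case. For intuition on the score itself, a toy run with the shapes implied by the surrounding code (`local_grouped` as `(groups, num_generations, n_reward_funcs)`, advantages flat; values invented):

```python
import torch

num_generations = 4
# Two groups of rewards: group 0 has spread, group 1 is constant (no signal).
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                        [0.5, 0.5, 0.5, 0.5]])
local_grouped = rewards.unsqueeze(-1)        # (groups, num_generations, 1 reward func)
per_group_std = local_grouped.std(dim=1)     # (groups, 1)
has_signal = (per_group_std > 0).any(dim=1)  # tensor([True, False])

advantages = torch.tensor([0.5, -0.5, 0.5, -0.5, 0.0, 0.0, 0.0, 0.0])
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
print(replay_scores)  # ~[1.155, 0.000]: only the mixed-reward group is worth replaying
```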
```diff
@@ -450,8 +450,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
                     group_data[key] = val[start:end].clone()
                 self._replay_buffer.add(replay_scores[gi].item(), group_data)

-        # Replace zero-signal groups (only in deferred path, not streaming)
-        if s_start is None:
+        # Replace zero-signal groups with high-signal replay buffer entries
+        if True:
             no_signal = ~has_signal
             n_replaced = 0
             replaced_ranges = []
```
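The guard flips from `if s_start is None:` (deferred path only) to `if True:`, enabling replacement on the streaming path as well, which is what makes the offset bookkeeping above necessary. The buffer contract visible in these hunks is small: `add(score, group_data)` stores a scored group of tensors, and a `sample(...)` (indexed as `sampled[0]` below) returns stored groups. A minimal score-weighted buffer satisfying that contract, offered as an assumption about `self._replay_buffer` rather than the repo's actual class:

```python
import random


class ScoredReplayBuffer:
    """Minimal sketch: keeps the highest-scoring groups, samples proportional to score."""

    def __init__(self, capacity: int = 64):
        self.capacity = capacity
        self.entries: list[tuple[float, dict]] = []  # (score, group_data)

    def add(self, score: float, group_data: dict) -> None:
        self.entries.append((score, group_data))
        # Keep only the top-`capacity` entries by score.
        self.entries.sort(key=lambda e: e[0], reverse=True)
        del self.entries[self.capacity:]

    def sample(self, k: int) -> list[dict]:
        if not self.entries:
            return []
        # Score-proportional sampling; clamp so zero scores stay sampleable.
        weights = [max(score, 1e-8) for score, _ in self.entries]
        picks = random.choices(self.entries, weights=weights, k=k)
        return [group_data for _, group_data in picks]
```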
```diff
@@ -462,7 +462,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
                     break
                 sampled_group = sampled[0]
                 gi = group_idx.item()
-                start, end = gi * num_generations, (gi + 1) * num_generations
+                start = offset + gi * num_generations
+                end = start + num_generations
                 for key, val in sampled_group.items():
                     if key in data and isinstance(data[key], torch.Tensor):
                         src = val.to(data[key].device)
```
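This last hunk mirrors the hoisted `offset`: the destination slice for a replacement group must be shifted by the chunk start. With the old `start, end = gi * num_generations, (gi + 1) * num_generations`, a streaming call with `s_start=64`, `num_generations=8`, and `gi=2` would have overwritten rows 16..24 of `data` instead of the group's actual rows 80..88. On the deferred path (`s_start is None`, so `offset == 0`) the two formulas agree, which is why the bug only surfaced when streaming.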