fix replay buffer
This commit is contained in:
@@ -567,9 +567,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
# Data producer (the proper architecture for async generation)
|
# Data producer (the proper architecture for async generation)
|
||||||
self.data_producer = None
|
self.data_producer = None
|
||||||
if getattr(self.args, "use_data_producer", False):
|
if getattr(self.args, "use_data_producer", False):
|
||||||
self.data_producer = self._create_data_producer(
|
self.data_producer = self._create_data_producer(kwargs["args"], kwargs["train_dataset"])
|
||||||
kwargs["args"], kwargs["train_dataset"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.async_prefetch and self.data_producer is None:
|
if self.args.async_prefetch and self.data_producer is None:
|
||||||
# Legacy path: direct _prepare_inputs override without data producer
|
# Legacy path: direct _prepare_inputs override without data producer
|
||||||
@@ -590,17 +588,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
config=producer_config,
|
config=producer_config,
|
||||||
prompt_dataset=train_dataset,
|
prompt_dataset=train_dataset,
|
||||||
num_generations=self.num_generations,
|
num_generations=self.num_generations,
|
||||||
generation_batch_size=getattr(
|
generation_batch_size=args.generation_batch_size,
|
||||||
args,
|
|
||||||
"generation_batch_size",
|
|
||||||
self._train_batch_size * args.gradient_accumulation_steps,
|
|
||||||
),
|
|
||||||
train_batch_size=args.per_device_train_batch_size,
|
train_batch_size=args.per_device_train_batch_size,
|
||||||
steps_per_generation=args.steps_per_generation,
|
steps_per_generation=args.steps_per_generation,
|
||||||
shuffle_dataset=getattr(self, "shuffle_dataset", True),
|
shuffle_dataset=getattr(self, "shuffle_dataset", True),
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
# Inject trainer reference (needs accelerator from super().__init__)
|
|
||||||
data_producer.set_trainer(self)
|
data_producer.set_trainer(self)
|
||||||
|
|
||||||
if args.async_prefetch:
|
if args.async_prefetch:
|
||||||
@@ -868,6 +861,28 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
output[k] = forward_kwargs[k]
|
output[k] = forward_kwargs[k]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Hooks (overridden by subclasses like FastAsyncGRPOTrainer)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _compute_rewards_for_batch(self, inputs, prompts, completions, completion_ids_list):
|
||||||
|
"""Compute rewards for a batch. Override for parallel workers, caching, etc."""
|
||||||
|
return self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
||||||
|
|
||||||
|
def _post_advantage_hook(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
rewards_per_func,
|
||||||
|
advantages,
|
||||||
|
inputs: list,
|
||||||
|
num_generations: int,
|
||||||
|
mode: str,
|
||||||
|
s_start: int | None = None,
|
||||||
|
s_end: int | None = None,
|
||||||
|
is_last_chunk: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Called after advantages are computed. Override for replay buffer, re-roll, etc."""
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Main-thread scoring
|
# Main-thread scoring
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -999,7 +1014,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
data["importance_sampling_ratio"] = is_ratio
|
data["importance_sampling_ratio"] = is_ratio
|
||||||
|
|
||||||
# --- Rewards ---
|
# --- Rewards ---
|
||||||
rewards_per_func = self._calculate_rewards(
|
rewards_per_func = self._compute_rewards_for_batch(
|
||||||
inputs, prompts, completions, completion_ids_list
|
inputs, prompts, completions, completion_ids_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1069,6 +1084,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
advantages = advantages[process_slice]
|
advantages = advantages[process_slice]
|
||||||
data["advantages"] = advantages
|
data["advantages"] = advantages
|
||||||
|
|
||||||
|
# --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
|
||||||
|
self._post_advantage_hook(
|
||||||
|
data, rewards_per_func, advantages, inputs, num_generations, mode,
|
||||||
|
)
|
||||||
|
|
||||||
# --- Metrics ---
|
# --- Metrics ---
|
||||||
for i, name in enumerate(self.reward_func_names):
|
for i, name in enumerate(self.reward_func_names):
|
||||||
self._metrics[mode][f"rewards/{name}/mean"].append(
|
self._metrics[mode][f"rewards/{name}/mean"].append(
|
||||||
@@ -1337,7 +1357,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
data["ref_per_token_logps"][s_start:s_end] = ref_logps
|
data["ref_per_token_logps"][s_start:s_end] = ref_logps
|
||||||
|
|
||||||
# --- Rewards ---
|
# --- Rewards ---
|
||||||
rewards_per_func = self._calculate_rewards(
|
rewards_per_func = self._compute_rewards_for_batch(
|
||||||
inputs, prompts, completions, completion_ids_list
|
inputs, prompts, completions, completion_ids_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1405,6 +1425,12 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device)
|
data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device)
|
||||||
data["advantages"][s_start:s_end] = advantages
|
data["advantages"][s_start:s_end] = advantages
|
||||||
|
|
||||||
|
# --- Post-advantage hook (for replay buffer, re-roll, etc.) ---
|
||||||
|
self._post_advantage_hook(
|
||||||
|
data, rewards_per_func, advantages, inputs, num_generations, mode,
|
||||||
|
s_start=s_start, s_end=s_end, is_last_chunk=is_last_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
# --- Chunk metrics ---
|
# --- Chunk metrics ---
|
||||||
for i, name in enumerate(self.reward_func_names):
|
for i, name in enumerate(self.reward_func_names):
|
||||||
self._metrics[mode][f"rewards/{name}/mean"].append(
|
self._metrics[mode][f"rewards/{name}/mean"].append(
|
||||||
|
|||||||
@@ -430,11 +430,11 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
)
|
)
|
||||||
per_group_std = local_grouped.std(dim=1)
|
per_group_std = local_grouped.std(dim=1)
|
||||||
has_signal = (per_group_std > 0).any(dim=1)
|
has_signal = (per_group_std > 0).any(dim=1)
|
||||||
|
offset = s_start or 0
|
||||||
|
|
||||||
if has_signal.any():
|
if has_signal.any():
|
||||||
grouped_adv = advantages.view(-1, num_generations)
|
grouped_adv = advantages.view(-1, num_generations)
|
||||||
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
|
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
|
||||||
offset = s_start or 0
|
|
||||||
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
|
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
|
||||||
gi = group_idx.item()
|
gi = group_idx.item()
|
||||||
start = offset + gi * num_generations
|
start = offset + gi * num_generations
|
||||||
@@ -450,8 +450,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
group_data[key] = val[start:end].clone()
|
group_data[key] = val[start:end].clone()
|
||||||
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
||||||
|
|
||||||
# Replace zero-signal groups (only in deferred path, not streaming)
|
# Replace zero-signal groups with high-signal replay buffer entries
|
||||||
if s_start is None:
|
if True:
|
||||||
no_signal = ~has_signal
|
no_signal = ~has_signal
|
||||||
n_replaced = 0
|
n_replaced = 0
|
||||||
replaced_ranges = []
|
replaced_ranges = []
|
||||||
@@ -462,7 +462,8 @@ class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
|||||||
break
|
break
|
||||||
sampled_group = sampled[0]
|
sampled_group = sampled[0]
|
||||||
gi = group_idx.item()
|
gi = group_idx.item()
|
||||||
start, end = gi * num_generations, (gi + 1) * num_generations
|
start = offset + gi * num_generations
|
||||||
|
end = start + num_generations
|
||||||
for key, val in sampled_group.items():
|
for key, val in sampled_group.items():
|
||||||
if key in data and isinstance(data[key], torch.Tensor):
|
if key in data and isinstance(data[key], torch.Tensor):
|
||||||
src = val.to(data[key].device)
|
src = val.to(data[key].device)
|
||||||
|
|||||||
Reference in New Issue
Block a user