|
|
|
|
@@ -89,6 +89,12 @@ class AsyncGRPOConfig(GRPOConfig):
|
|
|
|
|
does not define them, the defaults below ensure everything works.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# --- Data producer ---
|
|
|
|
|
use_data_producer: bool = field(
|
|
|
|
|
default=False,
|
|
|
|
|
metadata={"help": "Use the GRPODataProducer protocol for online data generation."},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# --- Async data production ---
|
|
|
|
|
async_prefetch: bool = field(
|
|
|
|
|
default=False,
|
|
|
|
|
@@ -299,6 +305,195 @@ class DataProducerCallback:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# RolloutDataset + GRPODataProducer
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RolloutDataset(Dataset):
|
|
|
|
|
"""A Dataset wrapping the output dict from _generate_and_score_completions.
|
|
|
|
|
|
|
|
|
|
Per-sample tensors are sliced by index; shared metadata is passed through.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
_ALWAYS_SHARED = frozenset({"num_items_in_batch", "_pending_policy_logps"})
|
|
|
|
|
|
|
|
|
|
def __init__(self, data: dict[str, Any]):
|
|
|
|
|
self._data = data
|
|
|
|
|
self._shared_keys: set[str] = set()
|
|
|
|
|
self._sample_keys: set[str] = set()
|
|
|
|
|
|
|
|
|
|
for key, val in data.items():
|
|
|
|
|
if key in self._ALWAYS_SHARED:
|
|
|
|
|
self._shared_keys.add(key)
|
|
|
|
|
elif not isinstance(val, torch.Tensor):
|
|
|
|
|
self._shared_keys.add(key)
|
|
|
|
|
elif val.dim() == 0:
|
|
|
|
|
self._shared_keys.add(key)
|
|
|
|
|
else:
|
|
|
|
|
self._sample_keys.add(key)
|
|
|
|
|
|
|
|
|
|
self._num_samples = 0
|
|
|
|
|
for key in self._sample_keys:
|
|
|
|
|
n = data[key].size(0)
|
|
|
|
|
if self._num_samples == 0:
|
|
|
|
|
self._num_samples = n
|
|
|
|
|
elif n != self._num_samples:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Inconsistent sample count: key '{key}' has {n}, expected {self._num_samples}"
|
|
|
|
|
)
|
|
|
|
|
if self._num_samples == 0:
|
|
|
|
|
raise ValueError("No per-sample tensors found in rollout data")
|
|
|
|
|
|
|
|
|
|
def __len__(self) -> int:
|
|
|
|
|
return self._num_samples
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx: int) -> dict[str, Any]:
|
|
|
|
|
item: dict[str, Any] = {}
|
|
|
|
|
for key in self._sample_keys:
|
|
|
|
|
item[key] = self._data[key][idx]
|
|
|
|
|
for key in self._shared_keys:
|
|
|
|
|
item[key] = self._data[key]
|
|
|
|
|
return item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_rollout_collator(shared_keys: set[str]):
|
|
|
|
|
"""Return a collator that stacks per-sample tensors and passes shared keys through."""
|
|
|
|
|
|
|
|
|
|
def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]:
|
|
|
|
|
result: dict[str, Any] = {}
|
|
|
|
|
for key in batch[0]:
|
|
|
|
|
if key in shared_keys:
|
|
|
|
|
result[key] = batch[0][key]
|
|
|
|
|
else:
|
|
|
|
|
values = [item[key] for item in batch]
|
|
|
|
|
if isinstance(values[0], torch.Tensor):
|
|
|
|
|
result[key] = torch.stack(values)
|
|
|
|
|
else:
|
|
|
|
|
result[key] = values
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
return _collate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GRPODataProducer(BaseDataProducer):
|
|
|
|
|
"""Produces GRPO training rollouts using the trainer's generation pipeline.
|
|
|
|
|
|
|
|
|
|
Created before Trainer.__init__ completes; the trainer reference is injected
|
|
|
|
|
later via set_trainer().
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
config: ProducerConfig,
|
|
|
|
|
prompt_dataset,
|
|
|
|
|
*,
|
|
|
|
|
num_generations: int,
|
|
|
|
|
generation_batch_size: int,
|
|
|
|
|
train_batch_size: int,
|
|
|
|
|
steps_per_generation: int,
|
|
|
|
|
shuffle_dataset: bool,
|
|
|
|
|
seed: int,
|
|
|
|
|
):
|
|
|
|
|
super().__init__(config)
|
|
|
|
|
self._dataset = prompt_dataset
|
|
|
|
|
self._num_generations = num_generations
|
|
|
|
|
self._generation_batch_size = generation_batch_size
|
|
|
|
|
self._train_batch_size = train_batch_size
|
|
|
|
|
self._steps_per_generation = steps_per_generation
|
|
|
|
|
self._shuffle_dataset = shuffle_dataset
|
|
|
|
|
self._seed = seed
|
|
|
|
|
self._trainer = None
|
|
|
|
|
self._prompt_dl: DataLoader | None = None
|
|
|
|
|
self._prompt_iter = None
|
|
|
|
|
|
|
|
|
|
def set_trainer(self, trainer) -> None:
|
|
|
|
|
"""Inject the live trainer reference and create the prompt DataLoader."""
|
|
|
|
|
self._trainer = trainer
|
|
|
|
|
self._init_prompt_dataloader()
|
|
|
|
|
|
|
|
|
|
def _init_prompt_dataloader(self) -> None:
|
|
|
|
|
from functools import partial
|
|
|
|
|
from transformers.trainer_utils import seed_worker
|
|
|
|
|
|
|
|
|
|
trainer = self._trainer
|
|
|
|
|
sampler = RepeatSampler(
|
|
|
|
|
data_source=self._dataset,
|
|
|
|
|
mini_repeat_count=self._num_generations,
|
|
|
|
|
batch_size=self._generation_batch_size // self._num_generations,
|
|
|
|
|
repeat_count=1,
|
|
|
|
|
shuffle=self._shuffle_dataset,
|
|
|
|
|
seed=self._seed,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Use identity collator (same as stock GRPOTrainer)
|
|
|
|
|
def _identity(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
dl = DataLoader(
|
|
|
|
|
self._dataset,
|
|
|
|
|
batch_size=self._train_batch_size * self._steps_per_generation,
|
|
|
|
|
sampler=sampler,
|
|
|
|
|
collate_fn=_identity,
|
|
|
|
|
num_workers=trainer.args.dataloader_num_workers,
|
|
|
|
|
pin_memory=trainer.args.dataloader_pin_memory,
|
|
|
|
|
persistent_workers=trainer.args.dataloader_persistent_workers,
|
|
|
|
|
worker_init_fn=partial(
|
|
|
|
|
seed_worker,
|
|
|
|
|
num_workers=trainer.args.dataloader_num_workers,
|
|
|
|
|
rank=trainer.args.process_index,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
self._prompt_dl = trainer.accelerator.prepare(dl)
|
|
|
|
|
|
|
|
|
|
# Don't let accelerator track this dataloader
|
|
|
|
|
acc_dls = trainer.accelerator._dataloaders
|
|
|
|
|
if self._prompt_dl in acc_dls:
|
|
|
|
|
acc_dls.remove(self._prompt_dl)
|
|
|
|
|
|
|
|
|
|
self._prompt_iter = iter(self._prompt_dl)
|
|
|
|
|
|
|
|
|
|
def produce(
|
|
|
|
|
self,
|
|
|
|
|
model: Any,
|
|
|
|
|
global_step: int,
|
|
|
|
|
*,
|
|
|
|
|
skip_policy_logps: bool = False,
|
|
|
|
|
processing_class: Any = None,
|
|
|
|
|
accelerator: Any = None,
|
|
|
|
|
args: Any = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> RolloutDataset:
|
|
|
|
|
"""Generate a fresh GRPO training rollout."""
|
|
|
|
|
try:
|
|
|
|
|
inputs = next(self._prompt_iter)
|
|
|
|
|
except StopIteration:
|
|
|
|
|
self._prompt_iter = iter(self._prompt_dl)
|
|
|
|
|
inputs = next(self._prompt_iter)
|
|
|
|
|
|
|
|
|
|
if skip_policy_logps:
|
|
|
|
|
# Async path: use _generate_only (generation without scoring) which
|
|
|
|
|
# works on stock TRL (no skip_policy_logps parameter needed).
|
|
|
|
|
output = self._trainer._generate_only(inputs)
|
|
|
|
|
else:
|
|
|
|
|
# Sync path: full generation + scoring
|
|
|
|
|
output = self._trainer._generate_and_score_completions(inputs)
|
|
|
|
|
|
|
|
|
|
# Strip non-sequence metadata before shuffling
|
|
|
|
|
metadata = {}
|
|
|
|
|
for key in list(output.keys()):
|
|
|
|
|
val = output[key]
|
|
|
|
|
if not isinstance(val, (torch.Tensor, list)):
|
|
|
|
|
metadata[key] = output.pop(key)
|
|
|
|
|
elif isinstance(val, torch.Tensor) and val.dim() == 0:
|
|
|
|
|
metadata[key] = output.pop(key)
|
|
|
|
|
|
|
|
|
|
output = shuffle_sequence_dict(output)
|
|
|
|
|
output.update(metadata)
|
|
|
|
|
|
|
|
|
|
return RolloutDataset(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Trainer
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
@@ -331,9 +526,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|
|
|
|
self._buffered_inputs: list | None = None # override stock attr
|
|
|
|
|
self._current_train_step_time = 0.0
|
|
|
|
|
|
|
|
|
|
if self.args.async_prefetch:
|
|
|
|
|
# Data producer (the proper architecture for async generation)
|
|
|
|
|
self.data_producer = None
|
|
|
|
|
if getattr(self.args, "use_data_producer", False):
|
|
|
|
|
self.data_producer = self._create_data_producer()
|
|
|
|
|
|
|
|
|
|
if self.args.async_prefetch and self.data_producer is None:
|
|
|
|
|
# Legacy path: direct _prepare_inputs override without data producer
|
|
|
|
|
self._setup_async()
|
|
|
|
|
|
|
|
|
|
def _create_data_producer(self):
|
|
|
|
|
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
|
|
|
|
|
args = self.args
|
|
|
|
|
producer_config = ProducerConfig(
|
|
|
|
|
mini_epochs=args.num_iterations,
|
|
|
|
|
max_rollouts=None,
|
|
|
|
|
eval_during_produce=False,
|
|
|
|
|
empty_cache_before_produce=True,
|
|
|
|
|
empty_cache_after_produce=True,
|
|
|
|
|
async_prefetch=args.async_prefetch,
|
|
|
|
|
prefetch_depth=args.prefetch_depth,
|
|
|
|
|
)
|
|
|
|
|
data_producer = GRPODataProducer(
|
|
|
|
|
config=producer_config,
|
|
|
|
|
prompt_dataset=self.train_dataset,
|
|
|
|
|
num_generations=self.num_generations,
|
|
|
|
|
generation_batch_size=getattr(
|
|
|
|
|
args, "generation_batch_size",
|
|
|
|
|
self._train_batch_size * args.gradient_accumulation_steps,
|
|
|
|
|
),
|
|
|
|
|
train_batch_size=args.per_device_train_batch_size,
|
|
|
|
|
steps_per_generation=args.steps_per_generation,
|
|
|
|
|
shuffle_dataset=getattr(self, "shuffle_dataset", True),
|
|
|
|
|
seed=args.seed,
|
|
|
|
|
)
|
|
|
|
|
# Inject trainer reference (needs accelerator from super().__init__)
|
|
|
|
|
data_producer.set_trainer(self)
|
|
|
|
|
|
|
|
|
|
if args.async_prefetch:
|
|
|
|
|
data_producer = AsyncDataProducer(
|
|
|
|
|
data_producer,
|
|
|
|
|
background_produce_kwargs={"skip_policy_logps": True},
|
|
|
|
|
)
|
|
|
|
|
return data_producer
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# Async setup / teardown
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
@@ -1056,11 +1292,61 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
def _prepare_inputs(self, generation_batch):
|
|
|
|
|
"""Override to support async prefetch with optional streaming scoring."""
|
|
|
|
|
"""Override to support data producer and async prefetch paths."""
|
|
|
|
|
mode = "train" if self.model.training else "eval"
|
|
|
|
|
if mode != "train" or not self.args.async_prefetch:
|
|
|
|
|
return super()._prepare_inputs(generation_batch)
|
|
|
|
|
|
|
|
|
|
# --- Data producer path ---
|
|
|
|
|
if mode == "train" and self.data_producer is not None:
|
|
|
|
|
return self._prepare_inputs_data_producer(generation_batch)
|
|
|
|
|
|
|
|
|
|
# --- Legacy async prefetch path (no data producer) ---
|
|
|
|
|
if mode == "train" and self.args.async_prefetch:
|
|
|
|
|
return self._prepare_inputs_legacy_async(generation_batch)
|
|
|
|
|
|
|
|
|
|
# --- Stock path ---
|
|
|
|
|
return super()._prepare_inputs(generation_batch)
|
|
|
|
|
|
|
|
|
|
def _prepare_inputs_data_producer(self, generation_batch):
|
|
|
|
|
"""Data producer path: produce rollout, score deferred logps, split into micro-batches."""
|
|
|
|
|
# Return from buffer if available
|
|
|
|
|
if self._buffered_inputs:
|
|
|
|
|
return self._buffered_inputs.pop(0)
|
|
|
|
|
|
|
|
|
|
# Produce a new rollout
|
|
|
|
|
self._maybe_sync_vllm_weights()
|
|
|
|
|
rollout_dataset = self.data_producer.produce(
|
|
|
|
|
self.model, self.state.global_step,
|
|
|
|
|
processing_class=self.processing_class,
|
|
|
|
|
accelerator=self.accelerator,
|
|
|
|
|
args=self.args,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Convert RolloutDataset back to a dict for scoring/splitting
|
|
|
|
|
rollout = rollout_dataset._data
|
|
|
|
|
|
|
|
|
|
# If async (skip_policy_logps=True), score deferred logps on main thread
|
|
|
|
|
if rollout.get("_pending_policy_logps"):
|
|
|
|
|
if self.args.streaming_partial_batch:
|
|
|
|
|
micro_batches = self._score_streaming(rollout)
|
|
|
|
|
else:
|
|
|
|
|
scored = self._compute_deferred_scores(rollout)
|
|
|
|
|
scored = split_pixel_values_by_grid(scored)
|
|
|
|
|
scored = shuffle_sequence_dict(scored)
|
|
|
|
|
batches = split_tensor_dict(scored, self.args.steps_per_generation)
|
|
|
|
|
micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
|
|
|
|
|
micro_batches = micro_batches * self.num_iterations
|
|
|
|
|
else:
|
|
|
|
|
# Sync path: data is already fully scored
|
|
|
|
|
rollout = split_pixel_values_by_grid(rollout)
|
|
|
|
|
batches = split_tensor_dict(rollout, self.args.steps_per_generation)
|
|
|
|
|
micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
|
|
|
|
|
micro_batches = micro_batches * self.num_iterations
|
|
|
|
|
|
|
|
|
|
self._buffered_inputs = micro_batches[1:]
|
|
|
|
|
return micro_batches[0]
|
|
|
|
|
|
|
|
|
|
def _prepare_inputs_legacy_async(self, generation_batch):
|
|
|
|
|
"""Legacy async path: direct queue-based prefetch without data producer."""
|
|
|
|
|
# Return from buffer if available
|
|
|
|
|
if self._buffered_inputs:
|
|
|
|
|
return self._buffered_inputs.pop(0)
|
|
|
|
|
|