implement data producer

This commit is contained in:
Wing Lian
2026-03-09 23:28:42 +00:00
parent f0c9e98699
commit 575425a36f
4 changed files with 304 additions and 6 deletions

View File

@@ -54,7 +54,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
async_grpo = bool(self.cfg.trl and (
getattr(self.cfg.trl, "async_prefetch", False) or getattr(self.cfg.trl, "use_data_producer", False)
))
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
@@ -153,7 +155,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
async_grpo = bool(self.cfg.trl and (
getattr(self.cfg.trl, "async_prefetch", False) or getattr(self.cfg.trl, "use_data_producer", False)
))
training_args_cls = GRPOStrategy.get_training_args_class(async_grpo=async_grpo)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -137,6 +137,8 @@ class GRPOStrategy:
)
# Async GRPO fields
if getattr(trl, "use_data_producer", None) is not None:
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:

View File

@@ -89,6 +89,12 @@ class AsyncGRPOConfig(GRPOConfig):
does not define them, the defaults below ensure everything works.
"""
# --- Data producer ---
use_data_producer: bool = field(
default=False,
metadata={"help": "Use the GRPODataProducer protocol for online data generation."},
)
# --- Async data production ---
async_prefetch: bool = field(
default=False,
@@ -299,6 +305,195 @@ class DataProducerCallback:
pass
# ---------------------------------------------------------------------------
# RolloutDataset + GRPODataProducer
# ---------------------------------------------------------------------------
class RolloutDataset(Dataset):
"""A Dataset wrapping the output dict from _generate_and_score_completions.
Per-sample tensors are sliced by index; shared metadata is passed through.
"""
_ALWAYS_SHARED = frozenset({"num_items_in_batch", "_pending_policy_logps"})
def __init__(self, data: dict[str, Any]):
self._data = data
self._shared_keys: set[str] = set()
self._sample_keys: set[str] = set()
for key, val in data.items():
if key in self._ALWAYS_SHARED:
self._shared_keys.add(key)
elif not isinstance(val, torch.Tensor):
self._shared_keys.add(key)
elif val.dim() == 0:
self._shared_keys.add(key)
else:
self._sample_keys.add(key)
self._num_samples = 0
for key in self._sample_keys:
n = data[key].size(0)
if self._num_samples == 0:
self._num_samples = n
elif n != self._num_samples:
raise ValueError(
f"Inconsistent sample count: key '{key}' has {n}, expected {self._num_samples}"
)
if self._num_samples == 0:
raise ValueError("No per-sample tensors found in rollout data")
def __len__(self) -> int:
return self._num_samples
def __getitem__(self, idx: int) -> dict[str, Any]:
item: dict[str, Any] = {}
for key in self._sample_keys:
item[key] = self._data[key][idx]
for key in self._shared_keys:
item[key] = self._data[key]
return item
def make_rollout_collator(shared_keys: set[str]):
"""Return a collator that stacks per-sample tensors and passes shared keys through."""
def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]:
result: dict[str, Any] = {}
for key in batch[0]:
if key in shared_keys:
result[key] = batch[0][key]
else:
values = [item[key] for item in batch]
if isinstance(values[0], torch.Tensor):
result[key] = torch.stack(values)
else:
result[key] = values
return result
return _collate
class GRPODataProducer(BaseDataProducer):
"""Produces GRPO training rollouts using the trainer's generation pipeline.
Created before Trainer.__init__ completes; the trainer reference is injected
later via set_trainer().
"""
def __init__(
self,
config: ProducerConfig,
prompt_dataset,
*,
num_generations: int,
generation_batch_size: int,
train_batch_size: int,
steps_per_generation: int,
shuffle_dataset: bool,
seed: int,
):
super().__init__(config)
self._dataset = prompt_dataset
self._num_generations = num_generations
self._generation_batch_size = generation_batch_size
self._train_batch_size = train_batch_size
self._steps_per_generation = steps_per_generation
self._shuffle_dataset = shuffle_dataset
self._seed = seed
self._trainer = None
self._prompt_dl: DataLoader | None = None
self._prompt_iter = None
def set_trainer(self, trainer) -> None:
"""Inject the live trainer reference and create the prompt DataLoader."""
self._trainer = trainer
self._init_prompt_dataloader()
def _init_prompt_dataloader(self) -> None:
from functools import partial
from transformers.trainer_utils import seed_worker
trainer = self._trainer
sampler = RepeatSampler(
data_source=self._dataset,
mini_repeat_count=self._num_generations,
batch_size=self._generation_batch_size // self._num_generations,
repeat_count=1,
shuffle=self._shuffle_dataset,
seed=self._seed,
)
# Use identity collator (same as stock GRPOTrainer)
def _identity(x):
return x
dl = DataLoader(
self._dataset,
batch_size=self._train_batch_size * self._steps_per_generation,
sampler=sampler,
collate_fn=_identity,
num_workers=trainer.args.dataloader_num_workers,
pin_memory=trainer.args.dataloader_pin_memory,
persistent_workers=trainer.args.dataloader_persistent_workers,
worker_init_fn=partial(
seed_worker,
num_workers=trainer.args.dataloader_num_workers,
rank=trainer.args.process_index,
),
)
self._prompt_dl = trainer.accelerator.prepare(dl)
# Don't let accelerator track this dataloader
acc_dls = trainer.accelerator._dataloaders
if self._prompt_dl in acc_dls:
acc_dls.remove(self._prompt_dl)
self._prompt_iter = iter(self._prompt_dl)
def produce(
self,
model: Any,
global_step: int,
*,
skip_policy_logps: bool = False,
processing_class: Any = None,
accelerator: Any = None,
args: Any = None,
**kwargs,
) -> RolloutDataset:
"""Generate a fresh GRPO training rollout."""
try:
inputs = next(self._prompt_iter)
except StopIteration:
self._prompt_iter = iter(self._prompt_dl)
inputs = next(self._prompt_iter)
if skip_policy_logps:
# Async path: use _generate_only (generation without scoring) which
# works on stock TRL (no skip_policy_logps parameter needed).
output = self._trainer._generate_only(inputs)
else:
# Sync path: full generation + scoring
output = self._trainer._generate_and_score_completions(inputs)
# Strip non-sequence metadata before shuffling
metadata = {}
for key in list(output.keys()):
val = output[key]
if not isinstance(val, (torch.Tensor, list)):
metadata[key] = output.pop(key)
elif isinstance(val, torch.Tensor) and val.dim() == 0:
metadata[key] = output.pop(key)
output = shuffle_sequence_dict(output)
output.update(metadata)
return RolloutDataset(output)
# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
@@ -331,9 +526,50 @@ class AsyncGRPOTrainer(GRPOTrainer):
self._buffered_inputs: list | None = None # override stock attr
self._current_train_step_time = 0.0
if self.args.async_prefetch:
# Data producer (the proper architecture for async generation)
self.data_producer = None
if getattr(self.args, "use_data_producer", False):
self.data_producer = self._create_data_producer()
if self.args.async_prefetch and self.data_producer is None:
# Legacy path: direct _prepare_inputs override without data producer
self._setup_async()
def _create_data_producer(self):
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
args = self.args
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
eval_during_produce=False,
empty_cache_before_produce=True,
empty_cache_after_produce=True,
async_prefetch=args.async_prefetch,
prefetch_depth=args.prefetch_depth,
)
data_producer = GRPODataProducer(
config=producer_config,
prompt_dataset=self.train_dataset,
num_generations=self.num_generations,
generation_batch_size=getattr(
args, "generation_batch_size",
self._train_batch_size * args.gradient_accumulation_steps,
),
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=getattr(self, "shuffle_dataset", True),
seed=args.seed,
)
# Inject trainer reference (needs accelerator from super().__init__)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
return data_producer
# ------------------------------------------------------------------
# Async setup / teardown
# ------------------------------------------------------------------
@@ -1056,11 +1292,61 @@ class AsyncGRPOTrainer(GRPOTrainer):
# ------------------------------------------------------------------
def _prepare_inputs(self, generation_batch):
"""Override to support async prefetch with optional streaming scoring."""
"""Override to support data producer and async prefetch paths."""
mode = "train" if self.model.training else "eval"
if mode != "train" or not self.args.async_prefetch:
return super()._prepare_inputs(generation_batch)
# --- Data producer path ---
if mode == "train" and self.data_producer is not None:
return self._prepare_inputs_data_producer(generation_batch)
# --- Legacy async prefetch path (no data producer) ---
if mode == "train" and self.args.async_prefetch:
return self._prepare_inputs_legacy_async(generation_batch)
# --- Stock path ---
return super()._prepare_inputs(generation_batch)
def _prepare_inputs_data_producer(self, generation_batch):
"""Data producer path: produce rollout, score deferred logps, split into micro-batches."""
# Return from buffer if available
if self._buffered_inputs:
return self._buffered_inputs.pop(0)
# Produce a new rollout
self._maybe_sync_vllm_weights()
rollout_dataset = self.data_producer.produce(
self.model, self.state.global_step,
processing_class=self.processing_class,
accelerator=self.accelerator,
args=self.args,
)
# Convert RolloutDataset back to a dict for scoring/splitting
rollout = rollout_dataset._data
# If async (skip_policy_logps=True), score deferred logps on main thread
if rollout.get("_pending_policy_logps"):
if self.args.streaming_partial_batch:
micro_batches = self._score_streaming(rollout)
else:
scored = self._compute_deferred_scores(rollout)
scored = split_pixel_values_by_grid(scored)
scored = shuffle_sequence_dict(scored)
batches = split_tensor_dict(scored, self.args.steps_per_generation)
micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
micro_batches = micro_batches * self.num_iterations
else:
# Sync path: data is already fully scored
rollout = split_pixel_values_by_grid(rollout)
batches = split_tensor_dict(rollout, self.args.steps_per_generation)
micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
micro_batches = micro_batches * self.num_iterations
self._buffered_inputs = micro_batches[1:]
return micro_batches[0]
def _prepare_inputs_legacy_async(self, generation_batch):
"""Legacy async path: direct queue-based prefetch without data producer."""
# Return from buffer if available
if self._buffered_inputs:
return self._buffered_inputs.pop(0)

View File

@@ -191,6 +191,12 @@ class TRLConfig(BaseModel):
)
# Async GRPO fields
use_data_producer: bool = Field(
default=False,
json_schema_extra={
"description": "Use the GRPODataProducer protocol for online data generation."
},
)
async_prefetch: bool = Field(
default=False,
json_schema_extra={