diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 01703c9ac..7759a8a7e 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -54,7 +54,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: from axolotl.core.trainers.grpo import GRPOStrategy - async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False)) + async_grpo = bool(self.cfg.trl and ( + getattr(self.cfg.trl, "async_prefetch", False) or getattr(self.cfg.trl, "use_data_producer", False) + )) trainer_cls = GRPOStrategy.get_trainer_class( sequence_parallel=self.cfg.context_parallel_size > 1, async_grpo=async_grpo, @@ -153,7 +155,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: from axolotl.core.trainers.grpo import GRPOStrategy - async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False)) + async_grpo = bool(self.cfg.trl and ( + getattr(self.cfg.trl, "async_prefetch", False) or getattr(self.cfg.trl, "use_data_producer", False) + )) training_args_cls = GRPOStrategy.get_training_args_class(async_grpo=async_grpo) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index a5ed9cbc6..4347b863a 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -137,6 +137,8 @@ class GRPOStrategy: ) # Async GRPO fields + if getattr(trl, "use_data_producer", None) is not None: + grpo_args_kwargs["use_data_producer"] = trl.use_data_producer if getattr(trl, "async_prefetch", None) is not None: grpo_args_kwargs["async_prefetch"] = trl.async_prefetch if getattr(trl, "prefetch_depth", None) is not None: diff --git a/src/axolotl/monkeypatch/trainer/async_grpo.py b/src/axolotl/monkeypatch/trainer/async_grpo.py index 5f57c18f3..1bc63b474 100644 --- a/src/axolotl/monkeypatch/trainer/async_grpo.py +++ b/src/axolotl/monkeypatch/trainer/async_grpo.py @@ -89,6 +89,12 @@ class AsyncGRPOConfig(GRPOConfig): does not define them, the defaults below ensure everything works. """ + # --- Data producer --- + use_data_producer: bool = field( + default=False, + metadata={"help": "Use the GRPODataProducer protocol for online data generation."}, + ) + # --- Async data production --- async_prefetch: bool = field( default=False, @@ -299,6 +305,195 @@ class DataProducerCallback: pass +# --------------------------------------------------------------------------- +# RolloutDataset + GRPODataProducer +# --------------------------------------------------------------------------- + + +class RolloutDataset(Dataset): + """A Dataset wrapping the output dict from _generate_and_score_completions. + + Per-sample tensors are sliced by index; shared metadata is passed through. + """ + + _ALWAYS_SHARED = frozenset({"num_items_in_batch", "_pending_policy_logps"}) + + def __init__(self, data: dict[str, Any]): + self._data = data + self._shared_keys: set[str] = set() + self._sample_keys: set[str] = set() + + for key, val in data.items(): + if key in self._ALWAYS_SHARED: + self._shared_keys.add(key) + elif not isinstance(val, torch.Tensor): + self._shared_keys.add(key) + elif val.dim() == 0: + self._shared_keys.add(key) + else: + self._sample_keys.add(key) + + self._num_samples = 0 + for key in self._sample_keys: + n = data[key].size(0) + if self._num_samples == 0: + self._num_samples = n + elif n != self._num_samples: + raise ValueError( + f"Inconsistent sample count: key '{key}' has {n}, expected {self._num_samples}" + ) + if self._num_samples == 0: + raise ValueError("No per-sample tensors found in rollout data") + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, idx: int) -> dict[str, Any]: + item: dict[str, Any] = {} + for key in self._sample_keys: + item[key] = self._data[key][idx] + for key in self._shared_keys: + item[key] = self._data[key] + return item + + +def make_rollout_collator(shared_keys: set[str]): + """Return a collator that stacks per-sample tensors and passes shared keys through.""" + + def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]: + result: dict[str, Any] = {} + for key in batch[0]: + if key in shared_keys: + result[key] = batch[0][key] + else: + values = [item[key] for item in batch] + if isinstance(values[0], torch.Tensor): + result[key] = torch.stack(values) + else: + result[key] = values + return result + + return _collate + + +class GRPODataProducer(BaseDataProducer): + """Produces GRPO training rollouts using the trainer's generation pipeline. + + Created before Trainer.__init__ completes; the trainer reference is injected + later via set_trainer(). + """ + + def __init__( + self, + config: ProducerConfig, + prompt_dataset, + *, + num_generations: int, + generation_batch_size: int, + train_batch_size: int, + steps_per_generation: int, + shuffle_dataset: bool, + seed: int, + ): + super().__init__(config) + self._dataset = prompt_dataset + self._num_generations = num_generations + self._generation_batch_size = generation_batch_size + self._train_batch_size = train_batch_size + self._steps_per_generation = steps_per_generation + self._shuffle_dataset = shuffle_dataset + self._seed = seed + self._trainer = None + self._prompt_dl: DataLoader | None = None + self._prompt_iter = None + + def set_trainer(self, trainer) -> None: + """Inject the live trainer reference and create the prompt DataLoader.""" + self._trainer = trainer + self._init_prompt_dataloader() + + def _init_prompt_dataloader(self) -> None: + from functools import partial + from transformers.trainer_utils import seed_worker + + trainer = self._trainer + sampler = RepeatSampler( + data_source=self._dataset, + mini_repeat_count=self._num_generations, + batch_size=self._generation_batch_size // self._num_generations, + repeat_count=1, + shuffle=self._shuffle_dataset, + seed=self._seed, + ) + + # Use identity collator (same as stock GRPOTrainer) + def _identity(x): + return x + + dl = DataLoader( + self._dataset, + batch_size=self._train_batch_size * self._steps_per_generation, + sampler=sampler, + collate_fn=_identity, + num_workers=trainer.args.dataloader_num_workers, + pin_memory=trainer.args.dataloader_pin_memory, + persistent_workers=trainer.args.dataloader_persistent_workers, + worker_init_fn=partial( + seed_worker, + num_workers=trainer.args.dataloader_num_workers, + rank=trainer.args.process_index, + ), + ) + self._prompt_dl = trainer.accelerator.prepare(dl) + + # Don't let accelerator track this dataloader + acc_dls = trainer.accelerator._dataloaders + if self._prompt_dl in acc_dls: + acc_dls.remove(self._prompt_dl) + + self._prompt_iter = iter(self._prompt_dl) + + def produce( + self, + model: Any, + global_step: int, + *, + skip_policy_logps: bool = False, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + **kwargs, + ) -> RolloutDataset: + """Generate a fresh GRPO training rollout.""" + try: + inputs = next(self._prompt_iter) + except StopIteration: + self._prompt_iter = iter(self._prompt_dl) + inputs = next(self._prompt_iter) + + if skip_policy_logps: + # Async path: use _generate_only (generation without scoring) which + # works on stock TRL (no skip_policy_logps parameter needed). + output = self._trainer._generate_only(inputs) + else: + # Sync path: full generation + scoring + output = self._trainer._generate_and_score_completions(inputs) + + # Strip non-sequence metadata before shuffling + metadata = {} + for key in list(output.keys()): + val = output[key] + if not isinstance(val, (torch.Tensor, list)): + metadata[key] = output.pop(key) + elif isinstance(val, torch.Tensor) and val.dim() == 0: + metadata[key] = output.pop(key) + + output = shuffle_sequence_dict(output) + output.update(metadata) + + return RolloutDataset(output) + + # --------------------------------------------------------------------------- # Trainer # --------------------------------------------------------------------------- @@ -331,9 +526,50 @@ class AsyncGRPOTrainer(GRPOTrainer): self._buffered_inputs: list | None = None # override stock attr self._current_train_step_time = 0.0 - if self.args.async_prefetch: + # Data producer (the proper architecture for async generation) + self.data_producer = None + if getattr(self.args, "use_data_producer", False): + self.data_producer = self._create_data_producer() + + if self.args.async_prefetch and self.data_producer is None: + # Legacy path: direct _prepare_inputs override without data producer self._setup_async() + def _create_data_producer(self): + """Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer).""" + args = self.args + producer_config = ProducerConfig( + mini_epochs=args.num_iterations, + max_rollouts=None, + eval_during_produce=False, + empty_cache_before_produce=True, + empty_cache_after_produce=True, + async_prefetch=args.async_prefetch, + prefetch_depth=args.prefetch_depth, + ) + data_producer = GRPODataProducer( + config=producer_config, + prompt_dataset=self.train_dataset, + num_generations=self.num_generations, + generation_batch_size=getattr( + args, "generation_batch_size", + self._train_batch_size * args.gradient_accumulation_steps, + ), + train_batch_size=args.per_device_train_batch_size, + steps_per_generation=args.steps_per_generation, + shuffle_dataset=getattr(self, "shuffle_dataset", True), + seed=args.seed, + ) + # Inject trainer reference (needs accelerator from super().__init__) + data_producer.set_trainer(self) + + if args.async_prefetch: + data_producer = AsyncDataProducer( + data_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + return data_producer + # ------------------------------------------------------------------ # Async setup / teardown # ------------------------------------------------------------------ @@ -1056,11 +1292,61 @@ class AsyncGRPOTrainer(GRPOTrainer): # ------------------------------------------------------------------ def _prepare_inputs(self, generation_batch): - """Override to support async prefetch with optional streaming scoring.""" + """Override to support data producer and async prefetch paths.""" mode = "train" if self.model.training else "eval" - if mode != "train" or not self.args.async_prefetch: - return super()._prepare_inputs(generation_batch) + # --- Data producer path --- + if mode == "train" and self.data_producer is not None: + return self._prepare_inputs_data_producer(generation_batch) + + # --- Legacy async prefetch path (no data producer) --- + if mode == "train" and self.args.async_prefetch: + return self._prepare_inputs_legacy_async(generation_batch) + + # --- Stock path --- + return super()._prepare_inputs(generation_batch) + + def _prepare_inputs_data_producer(self, generation_batch): + """Data producer path: produce rollout, score deferred logps, split into micro-batches.""" + # Return from buffer if available + if self._buffered_inputs: + return self._buffered_inputs.pop(0) + + # Produce a new rollout + self._maybe_sync_vllm_weights() + rollout_dataset = self.data_producer.produce( + self.model, self.state.global_step, + processing_class=self.processing_class, + accelerator=self.accelerator, + args=self.args, + ) + + # Convert RolloutDataset back to a dict for scoring/splitting + rollout = rollout_dataset._data + + # If async (skip_policy_logps=True), score deferred logps on main thread + if rollout.get("_pending_policy_logps"): + if self.args.streaming_partial_batch: + micro_batches = self._score_streaming(rollout) + else: + scored = self._compute_deferred_scores(rollout) + scored = split_pixel_values_by_grid(scored) + scored = shuffle_sequence_dict(scored) + batches = split_tensor_dict(scored, self.args.steps_per_generation) + micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches] + micro_batches = micro_batches * self.num_iterations + else: + # Sync path: data is already fully scored + rollout = split_pixel_values_by_grid(rollout) + batches = split_tensor_dict(rollout, self.args.steps_per_generation) + micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches] + micro_batches = micro_batches * self.num_iterations + + self._buffered_inputs = micro_batches[1:] + return micro_batches[0] + + def _prepare_inputs_legacy_async(self, generation_batch): + """Legacy async path: direct queue-based prefetch without data producer.""" # Return from buffer if available if self._buffered_inputs: return self._buffered_inputs.pop(0) diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index f316b4637..99cf6019d 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -191,6 +191,12 @@ class TRLConfig(BaseModel): ) # Async GRPO fields + use_data_producer: bool = Field( + default=False, + json_schema_extra={ + "description": "Use the GRPODataProducer protocol for online data generation." + }, + ) async_prefetch: bool = Field( default=False, json_schema_extra={