diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py
index 169e7de9f..3af8c4a28 100644
--- a/src/axolotl/core/trainers/grpo/async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/async_trainer.py
@@ -567,15 +567,16 @@ class AsyncGRPOTrainer(GRPOTrainer):
         # Data producer (the proper architecture for async generation)
         self.data_producer = None
         if getattr(self.args, "use_data_producer", False):
-            self.data_producer = self._create_data_producer()
+            self.data_producer = self._create_data_producer(
+                kwargs["args"], kwargs["train_dataset"]
+            )
 
         if self.args.async_prefetch and self.data_producer is None:
             # Legacy path: direct _prepare_inputs override without data producer
             self._setup_async()
 
-    def _create_data_producer(self):
+    def _create_data_producer(self, args, train_dataset):
         """Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
-        args = self.args
         producer_config = ProducerConfig(
             mini_epochs=args.num_iterations,
             max_rollouts=None,
@@ -587,7 +588,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
         )
         data_producer = GRPODataProducer(
             config=producer_config,
-            prompt_dataset=self.train_dataset,
+            prompt_dataset=train_dataset,
             num_generations=self.num_generations,
             generation_batch_size=getattr(
                 args,
diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py
index ac377360a..42da60d5c 100644
--- a/src/axolotl/utils/schemas/trl.py
+++ b/src/axolotl/utils/schemas/trl.py
@@ -282,7 +282,7 @@ class TRLConfig(BaseModel):
         },
     )
     reroll_start_fraction: float = Field(
-        default=0.5,
+        default=1.0,
         json_schema_extra={
             "description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
             "(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
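
The `reroll_start_fraction` default change (0.5 to 1.0) pushes the start of deferred re-rolling to the very end of training unless the user lowers it. Below is a minimal sketch of how such a fraction could be converted into an absolute step threshold; the helper name `reroll_start_step` and the `max_steps` parameter are illustrative assumptions, not the actual Axolotl or TRL API.

```python
# Illustrative sketch only -- these names are assumptions, not the Axolotl API.
def reroll_start_step(reroll_start_fraction: float, max_steps: int) -> int:
    """Convert the configured fraction of total training into an absolute step.

    With the new default of 1.0 the threshold equals max_steps, so deferred
    re-rolling of zero-signal prompts effectively stays off unless the user
    opts in with a smaller fraction.
    """
    return int(reroll_start_fraction * max_steps)


# Example: with 1,000 total steps, the old default (0.5) would start re-rolling
# at step 500; the new default (1.0) only at step 1,000, i.e. never in practice.
assert reroll_start_step(0.5, 1000) == 500
assert reroll_start_step(1.0, 1000) == 1000
```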