Pass args and train_dataset explicitly to _create_data_producer; change reroll_start_fraction default to 1.0

This commit is contained in:
Wing Lian
2026-03-09 19:59:09 -04:00
parent d69d52ba41
commit e380f6944d
2 changed files with 6 additions and 5 deletions

View File

@@ -567,15 +567,16 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Data producer (the proper architecture for async generation)
self.data_producer = None
if getattr(self.args, "use_data_producer", False):
self.data_producer = self._create_data_producer()
self.data_producer = self._create_data_producer(
kwargs["args"], kwargs["train_dataset"]
)
if self.args.async_prefetch and self.data_producer is None:
# Legacy path: direct _prepare_inputs override without data producer
self._setup_async()
def _create_data_producer(self):
def _create_data_producer(self, args, train_dataset):
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
args = self.args
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
@@ -587,7 +588,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
)
data_producer = GRPODataProducer(
config=producer_config,
prompt_dataset=self.train_dataset,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=getattr(
args,

View File

@@ -282,7 +282,7 @@ class TRLConfig(BaseModel):
},
)
reroll_start_fraction: float = Field(
default=0.5,
default=1.0,
json_schema_extra={
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "