handle call to create data producer
This commit is contained in:
@@ -567,15 +567,16 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
# Data producer (the proper architecture for async generation)
|
||||
self.data_producer = None
|
||||
if getattr(self.args, "use_data_producer", False):
|
||||
self.data_producer = self._create_data_producer()
|
||||
self.data_producer = self._create_data_producer(
|
||||
kwargs["args"], kwargs["train_dataset"]
|
||||
)
|
||||
|
||||
if self.args.async_prefetch and self.data_producer is None:
|
||||
# Legacy path: direct _prepare_inputs override without data producer
|
||||
self._setup_async()
|
||||
|
||||
def _create_data_producer(self):
|
||||
def _create_data_producer(self, args, train_dataset):
|
||||
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
|
||||
args = self.args
|
||||
producer_config = ProducerConfig(
|
||||
mini_epochs=args.num_iterations,
|
||||
max_rollouts=None,
|
||||
@@ -587,7 +588,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
)
|
||||
data_producer = GRPODataProducer(
|
||||
config=producer_config,
|
||||
prompt_dataset=self.train_dataset,
|
||||
prompt_dataset=train_dataset,
|
||||
num_generations=self.num_generations,
|
||||
generation_batch_size=getattr(
|
||||
args,
|
||||
|
||||
@@ -282,7 +282,7 @@ class TRLConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
reroll_start_fraction: float = Field(
|
||||
default=0.5,
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
|
||||
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
|
||||
|
||||
Reference in New Issue
Block a user