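Handle call to create data producer: pass args and train_dataset explicitly to _create_data_producer; change reroll_start_fraction default from 0.5 to 1.0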
This commit is contained in:
@@ -567,15 +567,16 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
# Data producer (the proper architecture for async generation)
|
# Data producer (the proper architecture for async generation)
|
||||||
self.data_producer = None
|
self.data_producer = None
|
||||||
if getattr(self.args, "use_data_producer", False):
|
if getattr(self.args, "use_data_producer", False):
|
||||||
self.data_producer = self._create_data_producer()
|
self.data_producer = self._create_data_producer(
|
||||||
|
kwargs["args"], kwargs["train_dataset"]
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.async_prefetch and self.data_producer is None:
|
if self.args.async_prefetch and self.data_producer is None:
|
||||||
# Legacy path: direct _prepare_inputs override without data producer
|
# Legacy path: direct _prepare_inputs override without data producer
|
||||||
self._setup_async()
|
self._setup_async()
|
||||||
|
|
||||||
def _create_data_producer(self):
|
def _create_data_producer(self, args, train_dataset):
|
||||||
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
|
"""Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer)."""
|
||||||
args = self.args
|
|
||||||
producer_config = ProducerConfig(
|
producer_config = ProducerConfig(
|
||||||
mini_epochs=args.num_iterations,
|
mini_epochs=args.num_iterations,
|
||||||
max_rollouts=None,
|
max_rollouts=None,
|
||||||
@@ -587,7 +588,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
)
|
)
|
||||||
data_producer = GRPODataProducer(
|
data_producer = GRPODataProducer(
|
||||||
config=producer_config,
|
config=producer_config,
|
||||||
prompt_dataset=self.train_dataset,
|
prompt_dataset=train_dataset,
|
||||||
num_generations=self.num_generations,
|
num_generations=self.num_generations,
|
||||||
generation_batch_size=getattr(
|
generation_batch_size=getattr(
|
||||||
args,
|
args,
|
||||||
|
|||||||
@@ -282,7 +282,7 @@ class TRLConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
reroll_start_fraction: float = Field(
|
reroll_start_fraction: float = Field(
|
||||||
default=0.5,
|
default=1.0,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
|
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
|
||||||
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
|
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
|
||||||
|
|||||||
Reference in New Issue
Block a user