fix worker_init_fn signature handling (#2769)
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -58,6 +59,42 @@ class AxolotlGRPOTrainer(
|
|||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
def get_train_dataloader(self):
    """Build the training dataloader and hand it to the accelerator.

    The batch size passed to the ``DataLoader`` is
    ``self._train_batch_size * self.args.steps_per_generation``.

    A ``datasets.Dataset`` is passed through ``self._remove_unused_columns``;
    any other dataset type instead has its collator passed through
    ``self._get_collator_with_removed_columns``. When the dataset is not a
    ``torch.utils.data.IterableDataset``, the sampler, ``drop_last``,
    ``worker_init_fn`` and ``prefetch_factor`` options are set as well.

    Returns:
        The result of ``self.accelerator.prepare`` applied to the new
        ``DataLoader``.

    Raises:
        ValueError: If ``self.train_dataset`` is ``None``.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    dataset = self.train_dataset
    collator = self.data_collator
    if isinstance(dataset, datasets.Dataset):
        dataset = self._remove_unused_columns(dataset, description="training")
    else:
        collator = self._get_collator_with_removed_columns(
            collator, description="training"
        )

    # NOTE: the original inline comment flags this batch-size scaling by
    # steps_per_generation as "the change" made by this override.
    generation_batch_size = self._train_batch_size * self.args.steps_per_generation
    loader_kwargs = dict(
        batch_size=generation_batch_size,
        collate_fn=collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    if not is_iterable:
        # These options are only set for non-iterable datasets.
        loader_kwargs["sampler"] = self._get_train_sampler()
        loader_kwargs["drop_last"] = self.args.dataloader_drop_last
        # Bind num_workers and rank as keywords on seed_worker.
        loader_kwargs["worker_init_fn"] = partial(
            seed_worker,
            num_workers=self.args.dataloader_num_workers,
            rank=self.args.process_index,
        )
        loader_kwargs["prefetch_factor"] = self.args.dataloader_prefetch_factor

    train_loader = DataLoader(dataset, **loader_kwargs)
    return self.accelerator.prepare(train_loader)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||||
|
|||||||
Reference in New Issue
Block a user