diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index dccc85d80..c97fccd31 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -3,6 +3,7 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
 import warnings
+from functools import partial
 from typing import Any
 
 import datasets
@@ -58,6 +59,42 @@ class AxolotlGRPOTrainer(
     _tag_names = ["trl", "grpo", "axolotl"]
 
+    def get_train_dataloader(self):
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator
+        if isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(
+                train_dataset, description="training"
+            )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="training"
+            )
+
+        dataloader_params = {
+            "batch_size": self._train_batch_size
+            * self.args.steps_per_generation,  # < this is the change
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = partial(
+                seed_worker,
+                num_workers=self.args.dataloader_num_workers,
+                rank=self.args.process_index,
+            )
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
 
 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""
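
Reviewer sketch (not part of the patch): the functional change is the
`batch_size` passed to `DataLoader`, which is now scaled by
`args.steps_per_generation`, presumably so a single dataloader fetch
supplies enough prompts for every optimizer step of one generation round.
Below is a minimal, self-contained illustration of that scaling with a
plain torch Dataset; `PromptDataset` and the two constants are
hypothetical stand-ins, not names from the patch:

import torch
from torch.utils.data import DataLoader, Dataset


class PromptDataset(Dataset):
    """Toy stand-in for the GRPO prompt dataset."""

    def __len__(self) -> int:
        return 32

    def __getitem__(self, idx: int) -> dict:
        return {"prompt_id": idx}


PER_DEVICE_BATCH_SIZE = 4   # stands in for self._train_batch_size
STEPS_PER_GENERATION = 2    # stands in for self.args.steps_per_generation

# Mirrors the `batch_size * steps_per_generation` change in the diff:
# each fetch yields one full generation round's worth of prompts.
loader = DataLoader(
    PromptDataset(),
    batch_size=PER_DEVICE_BATCH_SIZE * STEPS_PER_GENERATION,
    drop_last=True,
)

batch = next(iter(loader))
assert batch["prompt_id"].shape[0] == 8  # 4 prompts x 2 steps per fetch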