diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index e9d992262..f19232625 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -109,7 +109,6 @@ class MultipackDistributedDataloader:
         seq_max_length: int = 2048,
         batch_size: int = 1,
         sampler: Union[Sampler, DistributedSampler] = None,
-        seed: int = 0,
     ):
         # Dataset
         self.dataset = dataset
@@ -127,19 +126,10 @@ class MultipackDistributedDataloader:
             self.num_replicas = 1
             self.rank = 0

-        # Seed
-        self.seed = seed
-
-        # Epoch
-        self.epoch = 0
-
         # statistics
         self.eff_total_used = 0
         self.eff_total_slots = 0

-    def set_epoch(self, epoch: int):
-        self.epoch = epoch
-
     def generate_batches(self, set_stats=False):
         if self.sampler:
             indices = [idx for idx in self.sampler]
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3a6e8d298..cc68edb02 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -14,7 +14,7 @@ import transformers
 from datasets import Dataset
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -87,18 +87,26 @@ class AxolotlTrainer(Trainer):
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler

+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                seed=self.args.seed,
+            )
+        return super()._get_train_sampler()
+
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    self.train_dataset,
-                    batch_size=self._train_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=train_sampler,
-                )
+            return MultipackDistributedDataloader(
+                self.train_dataset,
+                batch_size=self._train_batch_size,
+                seq_max_length=self.args.max_seq_length,
+                collate_fn=self.data_collator,
+                sampler=train_sampler,
             )
         return super().get_train_dataloader()

@@ -278,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-        max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
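
Note (editor's sketch, not part of the patch): the `seed`/`set_epoch` bookkeeping removed from `MultipackDistributedDataloader` is subsumed by `torch.utils.data.DistributedSampler`, which the trainer now constructs in `_get_train_sampler`. The sampler owns the seed and epoch state and hands each rank its own index shard, which `generate_batches` then consumes via `indices = [idx for idx in self.sampler]`. A minimal standalone illustration follows; the toy dataset, world size of 2, and seed of 42 are illustrative assumptions.

```python
# Sketch: DistributedSampler shards the dataset indices across ranks and
# shuffles them deterministically from its own seed, so the dataloader no
# longer needs seed/set_epoch plumbing of its own.
from torch.utils.data import DistributedSampler

dataset = list(range(10))  # stand-in for a real Dataset; only len() is used here

for rank in range(2):  # pretend world_size == 2; no torch.distributed init needed
    sampler = DistributedSampler(
        dataset, num_replicas=2, rank=rank, shuffle=True, seed=42
    )
    sampler.set_epoch(0)  # per-epoch reshuffling now lives in the sampler
    print(rank, list(sampler))  # disjoint, seeded index shard for this rank
```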