From f733d0f31e545f177f310d14a6366bab729dc16e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 19 Aug 2023 10:35:04 -0400 Subject: [PATCH] disable eval using multipack for now (#437) --- src/axolotl/utils/trainer.py | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 7369d7c6d..5f24e13c0 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -14,7 +14,7 @@ import bitsandbytes as bnb import numpy as np import torch.cuda import transformers -from datasets import Dataset, set_caching_enabled +from datasets import set_caching_enabled from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import DataLoader, DistributedSampler, RandomSampler @@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer): ) return super().get_train_dataloader() - def get_eval_dataloader( - self, eval_dataset: Optional[Dataset] = None - ) -> Union[DataLoader, MultipackDistributedDataloader]: - if self.args.sample_packing: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset - ) - eval_sampler = self._get_eval_sampler(eval_dataset) - return self.accelerator.prepare( - MultipackDistributedDataloader( - eval_dataset, - batch_size=self.args.eval_batch_size, - seq_max_length=self.args.max_seq_length, - collate_fn=self.data_collator, - sampler=eval_sampler, - packing_efficiency_estimate=self.args.sample_packing_efficiency, - sample_packing_seq_len_multiplier=self.args.eval_batch_size, - device_count=int(os.environ.get("WORLD_SIZE", 1)), - ) - ) - return super().get_eval_dataloader(eval_dataset) + # def get_eval_dataloader( + # self, eval_dataset: Optional[Dataset] = None + # ) -> Union[DataLoader, MultipackDistributedDataloader]: + # if self.args.sample_packing: + # eval_dataset = ( + # eval_dataset if eval_dataset is not None else self.eval_dataset + # ) + # eval_sampler = self._get_eval_sampler(eval_dataset) + # return self.accelerator.prepare( + # MultipackDistributedDataloader( + # eval_dataset, + # batch_size=self.args.eval_batch_size, + # seq_max_length=self.args.max_seq_length, + # collate_fn=self.data_collator, + # sampler=eval_sampler, + # packing_efficiency_estimate=self.args.sample_packing_efficiency, + # sample_packing_seq_len_multiplier=self.args.eval_batch_size, + # device_count=int(os.environ.get("WORLD_SIZE", 1)), + # ) + # ) + # return super().get_eval_dataloader(eval_dataset) def compute_loss(self, model, inputs, return_outputs=False): # use one's weighted cross entropy loss calc