disable eval using multipack for now (#437)

This commit is contained in:
Wing Lian
2023-08-19 10:35:04 -04:00
committed by GitHub
parent 008505c8ae
commit f733d0f31e

View File

@@ -14,7 +14,7 @@ import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
 import transformers
-from datasets import Dataset, set_caching_enabled
+from datasets import set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
@@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_train_dataloader()
 
-    def get_eval_dataloader(
-        self, eval_dataset: Optional[Dataset] = None
-    ) -> Union[DataLoader, MultipackDistributedDataloader]:
-        if self.args.sample_packing:
-            eval_dataset = (
-                eval_dataset if eval_dataset is not None else self.eval_dataset
-            )
-            eval_sampler = self._get_eval_sampler(eval_dataset)
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    eval_dataset,
-                    batch_size=self.args.eval_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=eval_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                )
-            )
-        return super().get_eval_dataloader(eval_dataset)
+    # def get_eval_dataloader(
+    #     self, eval_dataset: Optional[Dataset] = None
+    # ) -> Union[DataLoader, MultipackDistributedDataloader]:
+    #     if self.args.sample_packing:
+    #         eval_dataset = (
+    #             eval_dataset if eval_dataset is not None else self.eval_dataset
+    #         )
+    #         eval_sampler = self._get_eval_sampler(eval_dataset)
+    #         return self.accelerator.prepare(
+    #             MultipackDistributedDataloader(
+    #                 eval_dataset,
+    #                 batch_size=self.args.eval_batch_size,
+    #                 seq_max_length=self.args.max_seq_length,
+    #                 collate_fn=self.data_collator,
+    #                 sampler=eval_sampler,
+    #                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
+    #                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+    #                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+    #             )
+    #         )
+    #     return super().get_eval_dataloader(eval_dataset)
 
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc