enable eval dataloader using multipack again

This commit is contained in:
Wing Lian
2023-08-17 18:07:02 -04:00
parent 7565fb9d63
commit 6c306d9186

View File

@@ -14,7 +14,7 @@ import bitsandbytes as bnb
import numpy as np import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
@@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer):
) )
return super().get_train_dataloader() return super().get_train_dataloader()
# def get_eval_dataloader( def get_eval_dataloader(
# self, eval_dataset: Optional[Dataset] = None self, eval_dataset: Optional[Dataset] = None
# ) -> Union[DataLoader, MultipackDistributedDataloader]: ) -> Union[DataLoader, MultipackDistributedDataloader]:
# if self.args.sample_packing: if self.args.sample_packing:
# eval_dataset = ( eval_dataset = (
# eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
# ) )
# eval_sampler = self._get_eval_sampler(eval_dataset) eval_sampler = self._get_eval_sampler(eval_dataset)
# return self.accelerator.prepare( return self.accelerator.prepare(
# MultipackDistributedDataloader( MultipackDistributedDataloader(
# eval_dataset, eval_dataset,
# batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
# seq_max_length=self.args.max_seq_length, seq_max_length=self.args.max_seq_length,
# collate_fn=self.data_collator, collate_fn=self.data_collator,
# sampler=eval_sampler, sampler=eval_sampler,
# packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
# sample_packing_seq_len_multiplier=self.args.eval_batch_size, sample_packing_seq_len_multiplier=self.args.eval_batch_size,
# device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
# ) )
# ) )
# return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc