From 99d844f2150d68b02ecbfee88a86786ad754d79f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 25 Aug 2023 20:38:43 -0400
Subject: [PATCH] benchmark callback has its own dataloader and collator

---
 src/axolotl/utils/callbacks.py |  5 +----
 src/axolotl/utils/trainer.py   | 40 +++++++++++++++++++++++++++++++++-
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 9ea949a7c..18896d5cd 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -217,9 +217,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
             metrics: Dict[str, float],  # pylint: disable=unused-argument
             **kwargs,  # pylint: disable=unused-argument
         ):
-            data_loader = trainer.get_eval_dataloader(bench_dataset)
-            source_max_len = trainer.data_collator.max_length
-            trainer.data_collator.max_length = args.bench_source_max_len
+            data_loader = trainer.get_bench_dataloader(bench_dataset)
             trainer.model.eval()
             preds, refs = [], []
             loss_bench = 0
@@ -258,6 +256,5 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 bench_scores.append(bench_score)
             results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
             trainer.log(results)
-            trainer.data_collator.max_length = source_max_len
 
     return BenchEvalCallback
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 27dba92ef..a7933dc35 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -14,7 +14,12 @@ import numpy as np
 import torch.cuda
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import (
+    DataLoader,
+    DistributedSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import SequentialDistributedSampler
 
@@ -158,6 +163,10 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -248,6 +257,30 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_eval_dataloader(eval_dataset)
 
+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
@@ -654,6 +687,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )
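
Note (not part of the patch): the sketch below shows how the new bench_data_collator keyword and get_bench_dataloader() are intended to fit together. It assumes model, tokenizer, train_dataset, bench_dataset, and training_args are already constructed elsewhere; the variable names are illustrative only, not code from the repo.

    # Sketch only: illustrates the intent of this patch under the assumptions above.
    import transformers

    from axolotl.utils.trainer import AxolotlTrainer

    collator_kwargs = {"padding": "longest"}  # hypothetical stand-in for data_collator_kwargs

    trainer = AxolotlTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, return_tensors="pt", **collator_kwargs
        ),
        # new keyword introduced by this patch: a separate collator used only
        # for the benchmark evaluation pass
        bench_data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, return_tensors="pt", **collator_kwargs
        ),
    )

    # The benchmark callback now builds its own dataloader from that collator,
    # so it no longer has to mutate trainer.data_collator.max_length around
    # get_eval_dataloader() and restore it afterwards.
    bench_loader = trainer.get_bench_dataloader(bench_dataset)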