benchmark callback has its own dataloader and collator

Wing Lian
2023-08-25 20:38:43 -04:00
parent aefd4d74fa
commit 99d844f215
2 changed files with 40 additions and 5 deletions


@@ -217,9 +217,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
             metrics: Dict[str, float],  # pylint: disable=unused-argument
             **kwargs,  # pylint: disable=unused-argument
         ):
-            data_loader = trainer.get_eval_dataloader(bench_dataset)
-            source_max_len = trainer.data_collator.max_length
-            trainer.data_collator.max_length = args.bench_source_max_len
+            data_loader = trainer.get_bench_dataloader(bench_dataset)
             trainer.model.eval()
             preds, refs = [], []
             loss_bench = 0
@@ -258,6 +256,5 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 bench_scores.append(bench_score)
             results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
             trainer.log(results)
-            trainer.data_collator.max_length = source_max_len
 
     return BenchEvalCallback
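In practice, the callback body now reduces to something like the sketch below. This is an approximation, not repo code: it assumes `trainer` was built with a `bench_data_collator`, and it skips the prediction decoding and per-subject accuracy bookkeeping the real callback does.

```python
# Minimal sketch of the benchmark eval loop after this change (simplified).
import torch

def run_bench_eval(trainer, bench_dataset):
    # the dataloader carries its own collator, so the training collator's
    # max_length no longer has to be saved, overridden, and restored
    data_loader = trainer.get_bench_dataloader(bench_dataset)
    trainer.model.eval()
    loss_bench = 0.0
    with torch.no_grad():
        # batches are already device-placed because get_bench_dataloader
        # runs the DataLoader through accelerator.prepare
        for batch in data_loader:
            outputs = trainer.model(**batch)  # collated batch includes labels
            loss_bench += outputs.loss.item()
    return loss_bench / len(data_loader)
```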


@@ -14,7 +14,12 @@ import numpy as np
 import torch.cuda
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import (
+    DataLoader,
+    DistributedSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import SequentialDistributedSampler
@@ -158,6 +163,10 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
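Because `bench_data_collator` is a keyword-only argument that is consumed before delegating to `Trainer.__init__`, it slots in alongside the usual trainer arguments. A hypothetical construction (all variable names illustrative; the real call site is the `setup_trainer` hunk further down):

```python
# Hypothetical call showing the new keyword argument; not repo code.
import transformers

trainer = AxolotlTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt"),
    # consumed by AxolotlTrainer.__init__, never seen by Trainer.__init__
    bench_data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, return_tensors="pt"
    ),
)
```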
@@ -248,6 +257,30 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_eval_dataloader(eval_dataset)
 
+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
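Outside the trainer, the single-process path is equivalent to building a plain `DataLoader` like the sketch below (function and parameter names are mine, not the repo's):

```python
# Illustrative sketch of what get_bench_dataloader builds in the
# single-process case (world_size <= 1); names are hypothetical.
from torch.utils.data import DataLoader, SequentialSampler

def build_bench_dataloader(bench_dataset, bench_collator, batch_size=8):
    return DataLoader(
        bench_dataset,
        batch_size=batch_size,
        collate_fn=bench_collator,                 # dedicated bench collator
        sampler=SequentialSampler(bench_dataset),  # deterministic eval order
        num_workers=2,
        pin_memory=True,
        drop_last=False,
    )
```

The `SequentialSampler` keeps example order deterministic so predictions stay aligned with their references; with more than one process the sampler is left unset and `accelerator.prepare` handles distributed sharding.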
@@ -654,6 +687,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )
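For reference, `transformers.DataCollatorForSeq2Seq` is a dataclass, so the two collators built from the shared `data_collator_kwargs` are fully independent objects. A hedged example of why that matters (tokenizer choice and kwarg values are illustrative, not the repo's):

```python
# Illustrative: two independent collators from the same kwargs. Because they
# are separate objects, a benchmark run can adjust its own collator (e.g.
# max_length) without perturbing the training collator, which is what the
# old save/override/restore dance in the callback worked around.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

data_collator_kwargs = {"padding": "longest"}  # illustrative value
train_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, return_tensors="pt", **data_collator_kwargs
)
bench_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, return_tensors="pt", **data_collator_kwargs
)
bench_collator.max_length = 2048  # affects benchmark batches only
```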