diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index fa3c907df..7901f3f62 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -142,7 +142,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
     bench_dataset = load_dataset(
         "pharaouk/dharma-1",
         data_files={
-            "eval": "dharma_1_full.json",
+            "eval": "dharma_1_mini.json",
         },
     )
     bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
@@ -218,7 +218,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
         **kwargs,  # pylint: disable=unused-argument
     ):
         data_loader = trainer.get_bench_dataloader(
-            bench_dataset.remove_columns(["input", "subject", "output"])
+            bench_dataset.remove_columns(["input", "subject", "output", "name"])
         )
         trainer.model.eval()
         preds, refs = [], []
@@ -242,7 +242,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 loss_bench += loss.item()
         # Extract results by subject.
         results = {"bench_loss": loss_bench / len(data_loader)}
-        bench_name = bench_dataset["subject"]
+        bench_name = bench_dataset["name"]
         bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
         for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
             bench_names[s]["preds"].append(p)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index a9f4dec26..ddc22b454 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -137,7 +137,7 @@ class AxolotlTrainingArguments(TrainingArguments):
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
     bench_dataset: Optional[str] = field(
-        default="mmlu-zs",
+        default="sampled",
         metadata={
             "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
         },
@@ -279,7 +279,8 @@ class AxolotlTrainer(Trainer):
             dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
 
-        return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+        return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
 
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc