more fixes

This commit is contained in:
Wing Lian
2023-08-25 22:46:28 -04:00
parent 8b16ecd448
commit a6c9223114
2 changed files with 6 additions and 5 deletions

View File

@@ -142,7 +142,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
bench_dataset = load_dataset(
"pharaouk/dharma-1",
data_files={
"eval": "dharma_1_full.json",
"eval": "dharma_1_mini.json",
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
@@ -218,7 +218,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output"])
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
@@ -242,7 +242,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
loss_bench += loss.item()
# Extract results by subject.
results = {"bench_loss": loss_bench / len(data_loader)}
bench_name = bench_dataset["subject"]
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)

View File

@@ -137,7 +137,7 @@ class AxolotlTrainingArguments(TrainingArguments):
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="mmlu-zs",
default="sampled",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
},
@@ -279,7 +279,8 @@ class AxolotlTrainer(Trainer):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc