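more fixes: load the dharma_1_mini eval file, group bench results by `name` instead of `subject`, default `bench_dataset` to `sampled`, and return the bench DataLoader without `accelerator.prepare`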

This commit is contained in:
Wing Lian
2023-08-25 22:46:28 -04:00
parent 8b16ecd448
commit a6c9223114
2 changed files with 6 additions and 5 deletions

View File

@@ -142,7 +142,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
bench_dataset = load_dataset( bench_dataset = load_dataset(
"pharaouk/dharma-1", "pharaouk/dharma-1",
data_files={ data_files={
"eval": "dharma_1_full.json", "eval": "dharma_1_mini.json",
}, },
) )
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject) bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
@@ -218,7 +218,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
data_loader = trainer.get_bench_dataloader( data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output"]) bench_dataset.remove_columns(["input", "subject", "output", "name"])
) )
trainer.model.eval() trainer.model.eval()
preds, refs = [], [] preds, refs = [], []
@@ -242,7 +242,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
loss_bench += loss.item() loss_bench += loss.item()
# Extract results by subject. # Extract results by subject.
results = {"bench_loss": loss_bench / len(data_loader)} results = {"bench_loss": loss_bench / len(data_loader)}
bench_name = bench_dataset["subject"] bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)} bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p) bench_names[s]["preds"].append(p)

View File

@@ -137,7 +137,7 @@ class AxolotlTrainingArguments(TrainingArguments):
default="eval", metadata={"help": "The benchmark split to run on"} default="eval", metadata={"help": "The benchmark split to run on"}
) )
bench_dataset: Optional[str] = field( bench_dataset: Optional[str] = field(
default="mmlu-zs", default="sampled",
metadata={ metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`" "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
}, },
@@ -279,7 +279,8 @@ class AxolotlTrainer(Trainer):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc