more fixes
This commit is contained in:
@@ -142,7 +142,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
bench_dataset = load_dataset(
|
||||
"pharaouk/dharma-1",
|
||||
data_files={
|
||||
"eval": "dharma_1_full.json",
|
||||
"eval": "dharma_1_mini.json",
|
||||
},
|
||||
)
|
||||
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
|
||||
@@ -218,7 +218,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
data_loader = trainer.get_bench_dataloader(
|
||||
bench_dataset.remove_columns(["input", "subject", "output"])
|
||||
bench_dataset.remove_columns(["input", "subject", "output", "name"])
|
||||
)
|
||||
trainer.model.eval()
|
||||
preds, refs = [], []
|
||||
@@ -242,7 +242,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
loss_bench += loss.item()
|
||||
# Extract results by subject.
|
||||
results = {"bench_loss": loss_bench / len(data_loader)}
|
||||
bench_name = bench_dataset["subject"]
|
||||
bench_name = bench_dataset["name"]
|
||||
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
||||
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
||||
bench_names[s]["preds"].append(p)
|
||||
|
||||
@@ -137,7 +137,7 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="mmlu-zs",
|
||||
default="sampled",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
||||
},
|
||||
@@ -279,7 +279,8 @@ class AxolotlTrainer(Trainer):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
|
||||
Reference in New Issue
Block a user