more fixes
This commit is contained in:
@@ -142,7 +142,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
bench_dataset = load_dataset(
|
bench_dataset = load_dataset(
|
||||||
"pharaouk/dharma-1",
|
"pharaouk/dharma-1",
|
||||||
data_files={
|
data_files={
|
||||||
"eval": "dharma_1_full.json",
|
"eval": "dharma_1_mini.json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
|
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
|
||||||
@@ -218,7 +218,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
data_loader = trainer.get_bench_dataloader(
|
data_loader = trainer.get_bench_dataloader(
|
||||||
bench_dataset.remove_columns(["input", "subject", "output"])
|
bench_dataset.remove_columns(["input", "subject", "output", "name"])
|
||||||
)
|
)
|
||||||
trainer.model.eval()
|
trainer.model.eval()
|
||||||
preds, refs = [], []
|
preds, refs = [], []
|
||||||
@@ -242,7 +242,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
loss_bench += loss.item()
|
loss_bench += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"bench_loss": loss_bench / len(data_loader)}
|
results = {"bench_loss": loss_bench / len(data_loader)}
|
||||||
bench_name = bench_dataset["subject"]
|
bench_name = bench_dataset["name"]
|
||||||
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
||||||
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
||||||
bench_names[s]["preds"].append(p)
|
bench_names[s]["preds"].append(p)
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
bench_dataset: Optional[str] = field(
|
bench_dataset: Optional[str] = field(
|
||||||
default="mmlu-zs",
|
default="sampled",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
||||||
},
|
},
|
||||||
@@ -279,7 +279,8 @@ class AxolotlTrainer(Trainer):
|
|||||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
# use one's weighted cross entropy loss calc
|
# use one's weighted cross entropy loss calc
|
||||||
|
|||||||
Reference in New Issue
Block a user