more fixes
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import barrier, is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||||
@@ -203,53 +203,47 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if is_main_process():
|
data_loader = trainer.get_eval_dataloader(bench_dataset)
|
||||||
data_loader = trainer.get_eval_dataloader(bench_dataset)
|
source_max_len = trainer.data_collator.max_length
|
||||||
source_max_len = trainer.data_collator.max_length
|
trainer.data_collator.max_length = args.bench_source_max_len
|
||||||
trainer.data_collator.max_length = args.bench_source_max_len
|
trainer.model.eval()
|
||||||
trainer.model.eval()
|
preds, refs = [], []
|
||||||
preds, refs = [], []
|
loss_bench = 0
|
||||||
loss_bench = 0
|
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||||
for batch in tqdm(data_loader, total=len(data_loader)):
|
(loss, logits, labels) = trainer.prediction_step(
|
||||||
(loss, logits, labels) = trainer.prediction_step(
|
trainer.model,
|
||||||
trainer.model,
|
batch,
|
||||||
batch,
|
prediction_loss_only=False,
|
||||||
prediction_loss_only=False,
|
)
|
||||||
)
|
# There are two tokens, the output, and eos token.
|
||||||
# There are two tokens, the output, and eos token.
|
for i, logit in enumerate(logits):
|
||||||
for i, logit in enumerate(logits):
|
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
||||||
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
preds.append(torch.argmax(logit_abcd).item())
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
refs += [
|
||||||
refs += [
|
abcd_idx.index(label) if labels in abcd_idx else -1
|
||||||
abcd_idx.index(label) if labels in abcd_idx else -1
|
for label in labels.tolist()
|
||||||
for label in labels.tolist()
|
]
|
||||||
]
|
loss_bench += loss.item()
|
||||||
loss_bench += loss.item()
|
# Extract results by subject.
|
||||||
# Extract results by subject.
|
results = {"bench_loss": loss_bench / len(data_loader)}
|
||||||
results = {"bench_loss": loss_bench / len(data_loader)}
|
subject = bench_dataset["subject"]
|
||||||
subject = bench_dataset["subject"]
|
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||||
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name
|
||||||
for s, p, r in zip( # pylint: disable=invalid-name
|
subjects[s]["preds"].append(p)
|
||||||
subject, preds, refs
|
subjects[s]["refs"].append(r)
|
||||||
):
|
subject_scores = []
|
||||||
subjects[s]["preds"].append(p)
|
for subject in subjects:
|
||||||
subjects[s]["refs"].append(r)
|
subject_score = accuracy.compute(
|
||||||
subject_scores = []
|
references=subjects[subject]["refs"],
|
||||||
for subject in subjects:
|
predictions=subjects[subject]["preds"],
|
||||||
subject_score = accuracy.compute(
|
)["accuracy"]
|
||||||
references=subjects[subject]["refs"],
|
if not pd.isna(subject_score):
|
||||||
predictions=subjects[subject]["preds"],
|
results[f"bench_{bench_split}_accuracy_{subject}"] = subject_score
|
||||||
)["accuracy"]
|
subject_scores.append(subject_score)
|
||||||
if not pd.isna(subject_score):
|
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
|
||||||
results[
|
trainer.log(results)
|
||||||
f"bench_{bench_split}_accuracy_{subject}"
|
trainer.data_collator.max_length = source_max_len
|
||||||
] = subject_score
|
|
||||||
subject_scores.append(subject_score)
|
|
||||||
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
|
|
||||||
trainer.log(results)
|
|
||||||
trainer.data_collator.max_length = source_max_len
|
|
||||||
barrier()
|
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
bench_source_max_len: int = field(
|
bench_source_max_len: int = field(
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -540,9 +540,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.do_bench_eval:
|
if cfg.do_bench_eval:
|
||||||
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||||
if cfg.bench_dataset:
|
if cfg.bench_dataset:
|
||||||
training_arguments_kwargs["mmlu_dataset"] = cfg.bench_dataset
|
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
|
|||||||
Reference in New Issue
Block a user