more fixes

This commit is contained in:
Wing Lian
2023-08-21 04:58:54 -04:00
parent 918e040601
commit 2455254b92
2 changed files with 46 additions and 52 deletions

View File

@@ -22,7 +22,7 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import barrier, is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -203,53 +203,47 @@ def bench_eval_callback_factory(trainer, tokenizer):
metrics: Dict[str, float], # pylint: disable=unused-argument metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
if is_main_process(): data_loader = trainer.get_eval_dataloader(bench_dataset)
data_loader = trainer.get_eval_dataloader(bench_dataset) source_max_len = trainer.data_collator.max_length
source_max_len = trainer.data_collator.max_length trainer.data_collator.max_length = args.bench_source_max_len
trainer.data_collator.max_length = args.bench_source_max_len trainer.model.eval()
trainer.model.eval() preds, refs = [], []
preds, refs = [], [] loss_bench = 0
loss_bench = 0 for batch in tqdm(data_loader, total=len(data_loader)):
for batch in tqdm(data_loader, total=len(data_loader)): (loss, logits, labels) = trainer.prediction_step(
(loss, logits, labels) = trainer.prediction_step( trainer.model,
trainer.model, batch,
batch, prediction_loss_only=False,
prediction_loss_only=False, )
) # There are two tokens, the output, and eos token.
# There are two tokens, the output, and eos token. for i, logit in enumerate(logits):
for i, logit in enumerate(logits): label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0] logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx] preds.append(torch.argmax(logit_abcd).item())
preds.append(torch.argmax(logit_abcd).item()) labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] refs += [
refs += [ abcd_idx.index(label) if labels in abcd_idx else -1
abcd_idx.index(label) if labels in abcd_idx else -1 for label in labels.tolist()
for label in labels.tolist() ]
] loss_bench += loss.item()
loss_bench += loss.item() # Extract results by subject.
# Extract results by subject. results = {"bench_loss": loss_bench / len(data_loader)}
results = {"bench_loss": loss_bench / len(data_loader)} subject = bench_dataset["subject"]
subject = bench_dataset["subject"] subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)} for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name
for s, p, r in zip( # pylint: disable=invalid-name subjects[s]["preds"].append(p)
subject, preds, refs subjects[s]["refs"].append(r)
): subject_scores = []
subjects[s]["preds"].append(p) for subject in subjects:
subjects[s]["refs"].append(r) subject_score = accuracy.compute(
subject_scores = [] references=subjects[subject]["refs"],
for subject in subjects: predictions=subjects[subject]["preds"],
subject_score = accuracy.compute( )["accuracy"]
references=subjects[subject]["refs"], if not pd.isna(subject_score):
predictions=subjects[subject]["preds"], results[f"bench_{bench_split}_accuracy_{subject}"] = subject_score
)["accuracy"] subject_scores.append(subject_score)
if not pd.isna(subject_score): results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
results[ trainer.log(results)
f"bench_{bench_split}_accuracy_{subject}" trainer.data_collator.max_length = source_max_len
] = subject_score
subject_scores.append(subject_score)
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
trainer.log(results)
trainer.data_collator.max_length = source_max_len
barrier()
return BenchEvalCallback return BenchEvalCallback

View File

@@ -147,7 +147,7 @@ class AxolotlTrainingArguments(TrainingArguments):
}, },
) )
bench_source_max_len: int = field( bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for mmlu."} default=2048, metadata={"help": "Maximum source sequence length for bench."}
) )
@@ -540,9 +540,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
) )
if cfg.do_bench_eval: if cfg.do_bench_eval:
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_bench_eval training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset: if cfg.bench_dataset:
training_arguments_kwargs["mmlu_dataset"] = cfg.bench_dataset training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1, max_steps=total_num_steps if cfg.max_steps else -1,