diff --git a/requirements.txt b/requirements.txt
index 0ae20f300..fcd7f9292 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
 addict
+evaluate
 fire
 PyYAML>=6.0
 datasets
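The new `evaluate` dependency is what scores the benchmark: the callback below calls `evaluate.load("accuracy")` and feeds it integer predictions and references. A minimal sketch of that metric API, with toy values and nothing axolotl-specific:

```python
import evaluate

# compute() returns a dict keyed by metric name
accuracy = evaluate.load("accuracy")
result = accuracy.compute(references=[0, 1, 2, 1], predictions=[0, 1, 1, 1])
print(result)  # {'accuracy': 0.75}
```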
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index ddc179f39..92333f4ca 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -1,9 +1,19 @@
 """Callbacks for Trainer class"""
+from __future__ import annotations
+
 import logging
 import os
+from typing import TYPE_CHECKING, Dict, List
 
+import evaluate
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
+from tqdm import tqdm
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -13,8 +23,19 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.distributed import (
+    barrier,
+    gather_scalar_from_all_ranks,
+    get_world_size,
+    is_main_process,
+    zero_first,
+)
+
+if TYPE_CHECKING:
+    from axolotl.utils.trainer import AxolotlTrainingArguments
 
 LOG = logging.getLogger("axolotl.callbacks")
+IGNORE_INDEX = -100
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -96,3 +117,192 @@ class GPUStatsCallback(
         log_gpu_memory_usage(LOG, "while training", self.cfg.device)
         self.logged = True
         return control
+
+
+def bench_eval_callback_factory(trainer, tokenizer):
+    accuracy = evaluate.load("accuracy")
+    abcd_idx = [
+        tokenizer("A", add_special_tokens=False).input_ids[0],
+        tokenizer("B", add_special_tokens=False).input_ids[0],
+        tokenizer("C", add_special_tokens=False).input_ids[0],
+        tokenizer("D", add_special_tokens=False).input_ids[0],
+        tokenizer("E", add_special_tokens=False).input_ids[0],
+        tokenizer("F", add_special_tokens=False).input_ids[0],
+        tokenizer("G", add_special_tokens=False).input_ids[0],
+    ]
+    bench_split = "eval"
+
+    def transform_bench_subject(example):
+        # Split on ':' and trim whitespace
+        parts = example["subject"].split(":")
+        first_part = (
+            parts[0].strip().lower().replace("-", "_")
+        )  # Lowercase the first part and replace hyphens with underscores
+        second_part = (
+            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
+        )  # Replace hyphens with underscores in the second part
+
+        # Return the transformed values
+        return {"name": first_part, "subject": second_part}
+
+    if trainer.args.bench_dataset == "mmlu-zs":
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "zero_shot_mmlu_val.json",
+                "test": "zero_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns("subject")
+    # MMLU Five-shot (Eval/Test only)
+    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "five_shot_mmlu_val.json",
+                "test": "five_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns('subject')
+    elif "/" in trainer.args.bench_dataset:
+        bench_ds = trainer.args.bench_dataset
+        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
+        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
+        bench_dataset = load_dataset(
+            bench_ds_name,
+            data_files={
+                "eval": bench_ds_data_file,
+            },
+        )
+        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
+    else:
+        raise ValueError(
+            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
+        )
+    bench_dataset = bench_dataset[trainer.args.bench_split]
+    if trainer.args.max_bench_samples is not None:
+        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
+
+    def tokenize_evals(example):
+        source = f"{tokenizer.bos_token}{example['input']}"
+        target = f"{example['output']}{tokenizer.eos_token}"
+
+        tokenized_source = tokenizer(
+            source,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        tokenized_target = tokenizer(
+            target,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
+        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
+            "input_ids"
+        ]
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "subject": example["subject"],
+        }
+
+    with zero_first(is_main_process()):
+        bench_dataset = bench_dataset.map(tokenize_evals)
+        # Keep only examples whose answer token (second-to-last label, before
+        # the eos token) is one of the A-G letter ids.
+        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
+
+    class BenchEvalCallback(TrainerCallback):
+        """
+        TrainerCallback that runs the MMLU evals
+        """
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,
+            state: TrainerState,  # pylint: disable=unused-argument
+            control: TrainerControl,  # pylint: disable=unused-argument
+            metrics: Dict[str, float],  # pylint: disable=unused-argument
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            data_loader = trainer.get_bench_dataloader(
+                bench_dataset.remove_columns(["input", "subject", "output", "name"])
+            )
+            trainer.model.eval()
+            preds, refs = [], []
+            loss_bench = 0
+            for batch in tqdm(data_loader, total=len(data_loader)):
+                (loss, logits, labels) = trainer.prediction_step(
+                    trainer.model,
+                    batch,
+                    prediction_loss_only=False,
+                )
+                # The last two label tokens are the answer token and the eos token.
+                for i, logit in enumerate(logits):
+                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
+                        0
+                    ][0]
+                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
+                    preds.append(torch.argmax(logit_abcd).item())
+                # Keep only the answer token from each (answer, eos) label pair.
+                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
+                refs += [
+                    abcd_idx.index(label) if label in abcd_idx else -1
+                    for label in labels.tolist()
+                ]
+                loss_bench += loss.item()
+            # Extract results by subject.
+            bench_name = bench_dataset["name"]
+            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
+            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
+                bench_names[s]["preds"].append(p)
+                bench_names[s]["refs"].append(r)
+            barrier()
+            local_bench_names = bench_names
+            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
+            # Gather results from all GPUs to GPU 0
+
+            loss_bench_ranks = gather_scalar_from_all_ranks(
+                lambda: loss_bench, get_world_size()
+            )
+            len_data_loader_ranks = gather_scalar_from_all_ranks(
+                lambda: len(data_loader), get_world_size()
+            )
+
+            if not is_main_process():
+                dist.gather_object(local_bench_names, dst=0)
+            else:
+                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
+                results = {"bench_loss": bench_loss}
+
+                # Combine results from all GPUs
+                combined_bench_names: Dict[str, Dict[str, List]] = {}
+                for bench_name in gathered_bench_names:
+                    for name, data in bench_name.items():
+                        if name not in combined_bench_names:
+                            combined_bench_names[name] = {"refs": [], "preds": []}
+                        combined_bench_names[name]["refs"].extend(data["refs"])
+                        combined_bench_names[name]["preds"].extend(data["preds"])
+
+                bench_scores = []
+                for (
+                    bench_name
+                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
+                    bench_score = accuracy.compute(
+                        references=combined_bench_names[bench_name]["refs"],
+                        predictions=combined_bench_names[bench_name]["preds"],
+                    )["accuracy"]
+                    if not pd.isna(bench_score):
+                        results[
+                            f"bench_{bench_split}_accuracy_{bench_name}"
+                        ] = bench_score
+                        bench_scores.append(bench_score)
+                    else:
+                        results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
+                        bench_scores.append(0.0)
+                results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
+                trainer.log(results)
+
+    return BenchEvalCallback
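To make the prediction step above concrete: `tokenize_evals` masks every label with `IGNORE_INDEX` except the answer letter and the eos token, so the first unmasked position locates the answer, and the logits one step earlier are restricted to the letter token ids. A standalone sketch with toy tensors (the token ids here are hypothetical, not real tokenizer output):

```python
import torch

IGNORE_INDEX = -100
abcd_idx = [319, 350, 315, 360]  # hypothetical ids for "A".."D"

logits = torch.randn(6, 1000)  # (seq_len, vocab) for one example
labels = torch.tensor([-100, -100, -100, -100, 350, 2])  # answer "B", then eos

# The first unmasked label position is the answer token's position.
answer_pos = (labels != IGNORE_INDEX).nonzero()[0][0]
# The model predicts position answer_pos from the logits at answer_pos - 1;
# restrict the comparison to the four letter tokens.
logit_abcd = logits[answer_pos - 1][abcd_idx]
pred = torch.argmax(logit_abcd).item()  # index into abcd_idx
ref = abcd_idx.index(labels[answer_pos].item())  # here: 1, i.e. "B"
```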
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index b3ea07c05..38d0d1e05 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -1,8 +1,10 @@
 """
 utility helpers for distributed checks
 """
+import os
 from contextlib import contextmanager
 
+import torch
 import torch.distributed as dist
 from accelerate import Accelerator
 
@@ -43,6 +45,10 @@ def is_main_process():
     return dist.get_rank() == 0
 
 
+def get_world_size():
+    return int(os.getenv("WORLD_SIZE", "1"))
+
+
 @contextmanager
 def zero_first(is_main):
     """
@@ -53,3 +59,35 @@ def zero_first(is_main):
     yield
     if is_main:  # then rank 0 waits after it has run the context
         barrier()
+
+
+def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+    """
+    Run a callable 'fn' on all ranks and gather the results on rank 0.
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+    - world_size (int, optional): Total number of processes in the current distributed setup.
+
+    Returns:
+    - A list of computed values from all ranks on rank 0 (the gathering rank), otherwise None.
+    """
+    value_scalar = fn()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+
+    if not is_main_process():
+        dist.gather(value_tensor, dst=0)
+    else:
+        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
+
+        # Convert tensors back to their original type (int or float)
+        gathered_values = []
+        for tensor in gathered_tensors:
+            if tensor == tensor.int():
+                gathered_values.append(int(tensor.item()))
+            else:
+                gathered_values.append(float(tensor.item()))
+        return gathered_values
+    return None
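A usage sketch for `gather_scalar_from_all_ranks`, assuming an initialized process group and one GPU per rank (e.g. launched via `torchrun --nproc_per_node=2`); the per-rank value is illustrative, and only rank 0 receives the list while every other rank gets `None`:

```python
import torch.distributed as dist

from axolotl.utils.distributed import (
    gather_scalar_from_all_ranks,
    get_world_size,
    is_main_process,
)

dist.init_process_group("nccl")

per_rank_loss = 1.5 + dist.get_rank()  # some scalar computed on each rank
losses = gather_scalar_from_all_ranks(lambda: per_rank_loss, get_world_size())
if is_main_process():
    print(sum(losses) / len(losses))  # mean across ranks
```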
+ """ + value_scalar = fn() + value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() + + if not is_main_process(): + dist.gather(value_tensor, dst=0) + else: + gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] + dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) + + # Convert tensors back to their original type (int or float) + gathered_values = [] + for tensor in gathered_tensors: + if tensor == tensor.int(): + gathered_values.append(int(tensor.item())) + else: + gathered_values.append(float(tensor.item())) + return gathered_values + return None diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index fcbdd6d3e..37578908e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -12,9 +12,15 @@ from typing import Optional, Union import numpy as np import torch.cuda +import transformers from datasets import Dataset, set_caching_enabled from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler +from torch.utils.data import ( + DataLoader, + DistributedSampler, + RandomSampler, + SequentialSampler, +) from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_pt_utils import SequentialDistributedSampler @@ -23,6 +29,7 @@ from axolotl.utils.callbacks import ( GPUStatsCallback, SaveBetterTransformerModelCallback, SavePeftModelCallback, + bench_eval_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -127,6 +134,27 @@ class AxolotlTrainingArguments(TrainingArguments): default=None, metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." 
@@ -136,6 +164,10 @@
 
     args = None  # type: AxolotlTrainingArguments
 
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -226,6 +258,31 @@
             )
         return super().get_eval_dataloader(eval_dataset)
 
+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
@@ -517,6 +574,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         "steps" if cfg.save_steps else "epoch"
     )
 
+    if cfg.do_bench_eval:
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
+        if cfg.bench_dataset:
+            training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
@@ -629,8 +691,16 @@
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )
 
+    if cfg.do_bench_eval:
+        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
+
     return trainer
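One detail worth noting about the new `bench_data_collator`: `DataCollatorForSeq2Seq` pads the `labels` column with `label_pad_token_id=-100` by default, which matches `IGNORE_INDEX` in the callback, so batching never introduces stray unmasked labels and the `(answer, eos)` label pairing in `on_evaluate` survives padding. A small sketch (the model choice is arbitrary; GPT-2 just needs a pad token assigned):

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt")
batch = collator(
    [
        {"input_ids": [1, 5, 9], "labels": [-100, 9, 2]},
        {"input_ids": [1, 5], "labels": [-100, 2]},
    ]
)
print(batch["labels"])  # tensor([[-100, 9, 2], [-100, 2, -100]])
```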