Compare commits

...

11 Commits

Author SHA1 Message Date
Wing Lian
53ce90d21e add sync_model_states parameter to fix resume from checkpoint with fsdp
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
fix formatting for linter
fixes FSDP resume from checkpoint (unpacked only)
chore: fix linter
chore: lint
2023-08-30 21:15:50 -07:00
Wing Lian
76576323df add eval benchmark callback (#441)
* add mmlu callback

* use hf dataset for mmlu evals

* default to mmlu-zs

* make sure to define all the explicit positional args

* include metrics in callback

* another callback fix for collator max len attribute

* fix mmlu evals

* sample benchmarks, ensure we drop long samples

* fix the data file

* fix elif and add better messaging

* more fixes

* rename mmlu to bench

* more fixes

* dataset handling and aggregate across benchmark

* better handling when no subjects

* benchmark callback has its own dataloader and collator

* fixes

* updated dataset

* more fixes

* missing transformers import

* improve support for customized dataset for bench evals

* gather benchmarks from all ranks

* fix for gather across multiple gpus
2023-08-29 13:24:19 -07:00
Wing Lian
548787daae customizable ascii art (#506) 2023-08-29 10:13:42 -07:00
Wing Lian
5ac3392075 support for datasets with multiple names (#480)
* support for datasets with multiple names

* update docs
2023-08-29 06:18:17 -07:00
Aman Gupta Karmani
e356b297cb remove --force-reinstall from Dockerfile to ensure correct pytorch version (#492) 2023-08-29 06:17:51 -07:00
NanoCode012
48c56470d0 Fix(doc): Clarify no amp to full yaml docs (#496) 2023-08-29 06:17:37 -07:00
Maxime
36b2e1cfee tweak: use default config file when only one file is present (#501) 2023-08-29 06:17:10 -07:00
Wing Lian
125cccb786 Refactor train cfg cli (#499)
* wip to cleanup cfg cli options

* fix launcher

* fix cli args
2023-08-29 05:37:53 -07:00
Aman Karmani
fd55bc87e2 use math.ceil instead of round /cc #498 2023-08-29 01:03:41 +00:00
Birch-san
8e197f6fb4 pad_to_worst_case_seq_len boolean, for testing memory limits (#498)
* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-08-28 18:47:16 -04:00
Aman Karmani
267b7b24e5 simplify linear layer locator 2023-08-28 09:45:16 -04:00
12 changed files with 526 additions and 97 deletions

View File

@@ -328,6 +328,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
name: enron_emails name: enron_emails
type: completion # format from earlier type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local # local
datasets: datasets:
- path: data.jsonl # or json - path: data.jsonl # or json
@@ -407,6 +416,10 @@ fp16: true
# Use CUDA tf32 # Use CUDA tf32
tf32: true # require >=ampere tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files # hf dataset repo | "json" for local dataset, make sure to fill data_files
@@ -459,6 +472,9 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048 # the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
@@ -610,9 +626,6 @@ deepspeed:
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset: pretraining_dataset:
@@ -652,6 +665,7 @@ fsdp:
fsdp_config: fsdp_config:
fsdp_offload_params: true fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
``` ```

View File

@@ -11,7 +11,7 @@ RUN apt-get update && \
WORKDIR /workspace WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \

View File

@@ -47,4 +47,3 @@ local_rank:
gradient_checkpointing: true gradient_checkpointing: true
fsdp: fsdp:
fsdp_config: fsdp_config:
collator_pad_to_longest: true

View File

@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict addict
evaluate
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets datasets
@@ -24,3 +25,4 @@ rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art

View File

@@ -6,14 +6,17 @@ import os
import random import random
import signal import signal
import sys import sys
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import transformers
import yaml import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
@@ -22,7 +25,7 @@ from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_model_config, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
@@ -37,16 +40,26 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(): @dataclass
ascii_art = """ class TrainerCliArgs:
dP dP dP """
88 88 88 dataclass representing the various non-training arguments
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88 """
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
"""
debug: bool = field(default=False)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
@@ -61,6 +74,8 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter: Optional[str]): def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
if prompter == "None":
prompter = None
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items(): for token, symbol in default_tokens.items():
@@ -135,6 +150,10 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?" "No YAML config files found in the specified directory. Are you using a .yml extension?"
) )
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:") print("Choose a YAML file:")
for idx, file in enumerate(yaml_files): for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}") print(f"{idx + 1}. {file}")
@@ -158,45 +177,20 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
def train( def train(
config: Path = Path("configs/"), *,
prepare_ds_only: bool = False, cfg: DictDefault,
**kwargs, cli_args: TrainerCliArgs,
): ):
print_axolotl_text_art()
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
# load the tokenizer first # load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
if ( if not (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference cli_args.shard or cli_args.merge_lora or cli_args.inference
): # don't need to load dataset for these ): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cfg.debug or "debug" in kwargs: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(
@@ -205,17 +199,17 @@ def train(
tokenizer, tokenizer,
) )
if prepare_ds_only: if cli_args.prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...") LOG.info("Finished preparing dataset. Exiting...")
return return
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...") LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
if "merge_lora" in kwargs and cfg.adapter is not None: if cli_args.merge_lora and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload() model = model.merge_and_unload()
model.to(dtype=torch.float16) model.to(dtype=torch.float16)
@@ -229,18 +223,13 @@ def train(
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if cfg.inference: if cli_args.inference:
LOG.info("calling do_inference function") LOG.debug("Running inference on model")
prompter: Optional[str] = "AlpacaPrompter" do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return return
if "shard" in kwargs: if cli_args.shard:
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return return
@@ -322,5 +311,51 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def do_train(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
train(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(train) fire.Fire(do_train)

View File

@@ -0,0 +1,45 @@
"""
Monkeypatch to fix fsdp set state when no previous state was set
"""
import contextlib
from typing import Generator, Optional
import torch
from torch import nn
from torch.distributed.fsdp.api import (
OptimStateDictConfig,
StateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
@staticmethod
@contextlib.contextmanager
def state_dict_type_patch(
module: nn.Module,
state_dict_type: StateDictType,
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> Generator:
prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
module,
state_dict_type,
state_dict_config,
optim_state_dict_config,
)
yield
if prev_state_dict_settings.state_dict_type:
FullyShardedDataParallel.set_state_dict_type(
module,
prev_state_dict_settings.state_dict_type,
prev_state_dict_settings.state_dict_config,
prev_state_dict_settings.optim_state_dict_config,
)
def replace_fsdp_state_dict_type():
torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
state_dict_type_patch
)

View File

@@ -1,9 +1,19 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
from __future__ import annotations
import logging import logging
import os import os
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
@@ -13,8 +23,19 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -96,3 +117,192 @@ class GPUStatsCallback(
log_gpu_memory_usage(LOG, "while training", self.cfg.device) log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True self.logged = True
return control return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "zero_shot_mmlu_val.json",
"test": "zero_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
# MMLU Five-shot (Eval/Test only)
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "five_shot_mmlu_val.json",
"test": "five_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
)
bench_dataset = bench_dataset[trainer.args.bench_split]
if trainer.args.max_bench_samples is not None:
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
def tokenize_evals(example):
source = f"{tokenizer.bos_token}{example['input']}"
target = f"{example['output']}{tokenizer.eos_token}"
tokenized_source = tokenizer(
source,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
tokenized_target = tokenizer(
target,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids"
]
return {
"input_ids": input_ids,
"labels": labels,
"subject": example["subject"],
}
with zero_first(is_main_process()):
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
TrainerCallback that runs the MMLU evals
"""
def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0
for batch in tqdm(data_loader, total=len(data_loader)):
(loss, logits, labels) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
# There are two tokens, the output, and eos token.
for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [
abcd_idx.index(label) if label in abcd_idx else -1
for label in labels.tolist()
]
loss_bench += loss.item()
# Extract results by subject.
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
barrier()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
loss_bench_ranks = gather_scalar_from_all_ranks(
lambda: loss_bench, get_world_size()
)
len_data_loader_ranks = gather_scalar_from_all_ranks(
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {"bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
return BenchEvalCallback

View File

@@ -152,6 +152,16 @@ def validate_config(cfg):
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models") raise ValueError("FSDP is not supported for falcon models")
if (
cfg.fsdp
and cfg.fsdp_config
and cfg.fsdp_config.fsdp_state_dict_type
and not cfg.fsdp_config.fsdp_sync_module_states
):
LOG.warning(
"We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
)
if ( if (
cfg.base_model and "mpt" in cfg.base_model.lower() cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing: ) and cfg.gradient_checkpointing:

View File

@@ -134,8 +134,17 @@ def load_tokenized_prepared_datasets(
seed = 42 seed = 42
datasets = [] datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for d in cfg.datasets: for d in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:

View File

@@ -1,8 +1,10 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os
from contextlib import contextmanager from contextlib import contextmanager
import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import Accelerator
@@ -43,6 +45,10 @@ def is_main_process():
return dist.get_rank() == 0 return dist.get_rank() == 0
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager @contextmanager
def zero_first(is_main): def zero_first(is_main):
""" """
@@ -53,3 +59,35 @@ def zero_first(is_main):
yield yield
if is_main: # then rank 0 waits after it has run the context if is_main: # then rank 0 waits after it has run the context
barrier() barrier()
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None

View File

@@ -5,12 +5,13 @@ import logging
import math import math
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -23,13 +24,17 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401 def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
def load_tokenizer(cfg): def load_tokenizer(cfg):
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
def load_model( def load_model(
cfg, tokenizer cfg: DictDefault,
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
@@ -97,14 +104,9 @@ def load_model(
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
)
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference: if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn, replace_llama_attn_with_flash_attn,
) )
@@ -146,7 +148,7 @@ def load_model(
if ( if (
cfg.is_llama_derived_model cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing) and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not cfg.inference and not inference
): ):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -424,15 +426,15 @@ def load_model(
return model, lora_config return model, lora_config
def load_adapter(model, cfg, adapter): def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
return load_llama_adapter(model, cfg) return load_llama_adapter(model, cfg)
@@ -464,12 +466,8 @@ def load_llama_adapter(model, cfg):
return model, peft_config return model, peft_config
def find_all_linear_names(bits, model): def find_all_linear_names(model):
cls = ( cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
bnb.nn.Linear4bit
if bits == 4
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls): if isinstance(module, cls):
@@ -482,21 +480,15 @@ def find_all_linear_names(bits, model):
return list(lora_module_names) return list(lora_module_names)
def load_lora(model, cfg): def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or []) lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear: if cfg.lora_target_linear:
bits = None linear_names = find_all_linear_names(model)
if cfg.load_in_4bit:
bits = 4
elif cfg.load_in_8bit:
bits = 8
linear_names = find_all_linear_names(bits, model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names)) lora_target_modules = list(set(lora_target_modules + linear_names))
@@ -516,7 +508,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
is_trainable=not cfg.inference, is_trainable=(not inference),
) )
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)

View File

@@ -12,9 +12,15 @@ from typing import Optional, Union
import numpy as np import numpy as np
import torch.cuda import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler from transformers.trainer_pt_utils import SequentialDistributedSampler
@@ -23,6 +29,7 @@ from axolotl.utils.callbacks import (
GPUStatsCallback, GPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
bench_eval_callback_factory,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -127,6 +134,27 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
) )
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -136,6 +164,10 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
): ):
@@ -226,6 +258,31 @@ class AxolotlTrainer(Trainer):
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc
# if self.args.sample_packing: # if self.args.sample_packing:
@@ -414,6 +471,9 @@ def setup_fsdp_envs(cfg):
os.environ[ os.environ[
"FSDP_TRANSFORMER_CLS_TO_WRAP" "FSDP_TRANSFORMER_CLS_TO_WRAP"
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type
replace_fsdp_state_dict_type()
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
@@ -517,6 +577,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"steps" if cfg.save_steps else "epoch" "steps" if cfg.save_steps else "epoch"
) )
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1, max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len, max_seq_length=cfg.sequence_len,
@@ -585,10 +650,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks.append(SaveBetterTransformerModelCallback) callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True, # True/"longest" is the default
} }
if cfg.collator_pad_to_longest: if cfg.pad_to_sequence_len:
data_collator_kwargs["padding"] = "longest" data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else: else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
@@ -627,8 +694,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,
), ),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks, callbacks=callbacks,
**trainer_kwargs, **trainer_kwargs,
) )
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
return trainer return trainer