Compare commits

..

17 Commits

Author SHA1 Message Date
Wing Lian
0026fcc3df remove torch install for now
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-01 08:15:22 -07:00
Wing Lian
b448c77148 address pr feedback 2023-08-29 22:45:22 -07:00
Wing Lian
c820d04669 gptq doesn't play well with sample packing 2023-08-29 12:15:31 -07:00
Wing Lian
588cd65a64 fix setup.py to use extra index url
install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd
2023-08-29 12:02:51 -07:00
Wing Lian
caa80e891d don't need explicit peft install for tests 2023-08-29 12:02:51 -07:00
Wing Lian
ac37753aa2 remove old gptq docker 2023-08-29 12:02:50 -07:00
Wing Lian
a29560004b more tweaks and add yml 2023-08-29 12:02:00 -07:00
Wing Lian
1deb767fe8 auto gptq support 2023-08-29 11:31:14 -07:00
Wing Lian
548787daae customizable ascii art (#506) 2023-08-29 10:13:42 -07:00
Wing Lian
5ac3392075 support for datasets with multiple names (#480)
* support for datasets with multiple names

* update docs
2023-08-29 06:18:17 -07:00
Aman Gupta Karmani
e356b297cb remove --force-reinstall from Dockerfile to ensure correct pytorch version (#492) 2023-08-29 06:17:51 -07:00
NanoCode012
48c56470d0 Fix(doc): Clarify no amp to full yaml docs (#496) 2023-08-29 06:17:37 -07:00
Maxime
36b2e1cfee tweak: use default config file when only one file is present (#501) 2023-08-29 06:17:10 -07:00
Wing Lian
125cccb786 Refactor train cfg cli (#499)
* wip to cleanup cfg cli options

* fix launcher

* fix cli args
2023-08-29 05:37:53 -07:00
Aman Karmani
fd55bc87e2 use math.ceil instead of round /cc #498 2023-08-29 01:03:41 +00:00
Birch-san
8e197f6fb4 pad_to_worst_case_seq_len boolean, for testing memory limits (#498)
* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-08-28 18:47:16 -04:00
Aman Karmani
267b7b24e5 simplify linear layer locator 2023-08-28 09:45:16 -04:00
17 changed files with 286 additions and 624 deletions

View File

@@ -23,11 +23,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted
steps:
- name: Checkout
@@ -73,11 +68,6 @@ jobs:
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted
steps:
- name: Checkout

View File

@@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .[peft]
pip install -e .
pip install -r requirements-tests.txt
- name: Run tests

View File

@@ -328,6 +328,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
name: enron_emails
type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local
datasets:
- path: data.jsonl # or json
@@ -407,6 +416,10 @@ fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with
datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files
@@ -459,6 +472,9 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
@@ -610,9 +626,6 @@ deepspeed:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:

View File

@@ -11,14 +11,13 @@ RUN apt-get update && \
WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
pip install -e .[flash-attn,gptq]; \
fi
# fix so that git fetch/pull from remote works

View File

@@ -1,8 +0,0 @@
# LLaMa 7B using LoRA
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```

View File

@@ -1,63 +0,0 @@
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
load_in_8bit: true
gptq: true
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-7b-lora-int4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
fp16: true
bf16: false
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
flash_attention:
gradient_checkpointing: true
gptq_groupsize: 128
gptq_model_v1: false
warmup_steps: 20
eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "<pad>"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,76 @@
base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_bits: 4
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- k_proj
- o_proj
- q_proj
- v_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000017
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -47,4 +47,3 @@ local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -1,10 +1,13 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
evaluate
fire
PyYAML>=6.0
datasets
@@ -25,3 +28,4 @@ rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art

View File

@@ -6,14 +6,17 @@ import os
import random
import signal
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import transformers
import yaml
# add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
@@ -22,7 +25,7 @@ from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.models import load_model, load_model_config, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
@@ -37,16 +40,26 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art():
ascii_art = """
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
"""
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process():
print(ascii_art)
@@ -61,6 +74,8 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
if prompter == "None":
prompter = None
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
@@ -135,6 +150,10 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
@@ -158,45 +177,20 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
print_axolotl_text_art()
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
if not (
cli_args.shard or cli_args.merge_lora or cli_args.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cfg.debug or "debug" in kwargs:
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
@@ -205,17 +199,17 @@ def train(
tokenizer,
)
if prepare_ds_only:
if cli_args.prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if "merge_lora" in kwargs and cfg.adapter is not None:
if cli_args.merge_lora and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
@@ -229,18 +223,13 @@ def train(
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
LOG.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
if cli_args.inference:
LOG.debug("Running inference on model")
do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
return
if "shard" in kwargs:
if cli_args.shard:
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
@@ -322,5 +311,51 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def do_train(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
train(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(train)
fire.Fire(do_train)

View File

@@ -2,15 +2,27 @@
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if "flash-attn" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
]
for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
@@ -19,12 +31,10 @@ setup(
package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"gptq": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
"auto-gptq",
],
"flash-attn": [
"flash-attn==2.0.8",
@@ -32,8 +42,5 @@ setup(
"extras": [
"deepspeed",
],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
},
)

View File

@@ -1,19 +1,9 @@
"""Callbacks for Trainer class"""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
TrainerCallback,
TrainerControl,
@@ -23,19 +13,8 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -117,192 +96,3 @@ class GPUStatsCallback(
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "zero_shot_mmlu_val.json",
"test": "zero_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
# MMLU Five-shot (Eval/Test only)
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "five_shot_mmlu_val.json",
"test": "five_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
)
bench_dataset = bench_dataset[trainer.args.bench_split]
if trainer.args.max_bench_samples is not None:
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
def tokenize_evals(example):
source = f"{tokenizer.bos_token}{example['input']}"
target = f"{example['output']}{tokenizer.eos_token}"
tokenized_source = tokenizer(
source,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
tokenized_target = tokenizer(
target,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids"
]
return {
"input_ids": input_ids,
"labels": labels,
"subject": example["subject"],
}
with zero_first(is_main_process()):
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
TrainerCallback that runs the MMLU evals
"""
def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0
for batch in tqdm(data_loader, total=len(data_loader)):
(loss, logits, labels) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
# There are two tokens, the output, and eos token.
for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [
abcd_idx.index(label) if label in abcd_idx else -1
for label in labels.tolist()
]
loss_bench += loss.item()
# Extract results by subject.
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
barrier()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
loss_bench_ranks = gather_scalar_from_all_ranks(
lambda: loss_bench, get_world_size()
)
len_data_loader_ranks = gather_scalar_from_all_ranks(
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {"bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
return BenchEvalCallback

View File

@@ -97,9 +97,7 @@ def validate_config(cfg):
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
raise ValueError("cfg.load_4bit parameter has been deprecated")
if cfg.adapter == "qlora":
if cfg.merge_lora:

View File

@@ -134,8 +134,17 @@ def load_tokenized_prepared_datasets(
seed = 42
datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
yield dataset
# pylint: disable=invalid-name
for d in cfg.datasets:
for d in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:

View File

@@ -1,10 +1,8 @@
"""
utility helpers for distributed checks
"""
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
from accelerate import Accelerator
@@ -45,10 +43,6 @@ def is_main_process():
return dist.get_rank() == 0
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager
def zero_first(is_main):
"""
@@ -59,35 +53,3 @@ def zero_first(is_main):
yield
if is_main: # then rank 0 waits after it has run the context
barrier()
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None

View File

@@ -4,18 +4,19 @@
import logging
import math
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
@@ -23,13 +24,17 @@ from transformers import ( # noqa: F401
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
def load_tokenizer(cfg):
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
def load_model(
cfg, tokenizer
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
@@ -97,14 +104,9 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
)
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
@@ -146,39 +148,22 @@ def load_model(
if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not cfg.inference
and not inference
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
try:
if cfg.gptq:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model,
)
replace_peft_model_with_int4_lora_model()
except Exception as err:
LOG.exception(err)
raise err
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
# TODO we should figure out how read the models config.json first
model_kwargs["quantization_config"] = GPTQConfig(
bits=cfg.gptq_bits,
disable_exllama=True,
)
if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
@@ -189,45 +174,7 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
try:
if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
try:
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
snapshot_download_kwargs[
"ignore_patterns"
] = cfg.base_model_ignore_patterns
cache_model_path = Path(
snapshot_download(base_model, **snapshot_download_kwargs)
)
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))
+ list(cache_model_path.glob("*.bin"))
)
if len(files) > 0:
model_path = str(files[0])
else:
LOG.warning(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
half=cfg.fp16,
groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
is_v1_model=cfg.gptq_model_v1
if cfg.gptq_model_v1 is not None
else True,
)
load_in_8bit = False
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
config_kwargs = {}
@@ -273,15 +220,24 @@ def load_model(
# )
# model.train() # sets to train instead of eval mode
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
@@ -357,11 +313,12 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
@@ -383,22 +340,10 @@ def load_model(
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
if cfg.gptq:
# Scales to half
LOG.info("Fitting 4bit scales and zeros to half")
for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(module)
):
if hasattr(module, "is_v1_model") and module.is_v1_model:
module.zeros = module.zeros.half()
module.scales = module.scales.half()
module.bias = module.bias.half()
if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and (cfg.gptq or cfg.load_in_4bit)
and (cfg.load_in_4bit)
):
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
@@ -424,15 +369,15 @@ def load_model(
return model, lora_config
def load_adapter(model, cfg, adapter):
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg)
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
@@ -464,12 +409,8 @@ def load_llama_adapter(model, cfg):
return model, peft_config
def find_all_linear_names(bits, model):
cls = (
bnb.nn.Linear4bit
if bits == 4
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
)
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
@@ -482,21 +423,15 @@ def find_all_linear_names(bits, model):
return list(lora_module_names)
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
bits = None
if cfg.load_in_4bit:
bits = 4
elif cfg.load_in_8bit:
bits = 8
linear_names = find_all_linear_names(bits, model)
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names))
@@ -516,7 +451,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
is_trainable=not cfg.inference,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, lora_config)

View File

@@ -12,15 +12,9 @@ from typing import Optional, Union
import numpy as np
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
@@ -29,7 +23,6 @@ from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -134,27 +127,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
@@ -164,10 +136,6 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -258,31 +226,6 @@ class AxolotlTrainer(Trainer):
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
@@ -504,23 +447,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
)
gradient_checkpointing_ratio = (
cfg.gradient_checkpointing_ratio
if cfg.gradient_checkpointing_ratio
else 1.0
)
apply_gradient_checkpointing(
model, checkpoint_ratio=gradient_checkpointing_ratio
)
else:
training_arguments_kwargs[
"gradient_checkpointing"
] = cfg.gradient_checkpointing
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
@@ -574,11 +501,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"steps" if cfg.save_steps else "epoch"
)
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -647,10 +569,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True,
"padding": True, # True/"longest" is the default
}
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
@@ -689,16 +613,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
return trainer