Merge branch 'main' into patch-1

This commit is contained in:
Angainor Development
2023-06-10 19:07:54 +02:00
committed by GitHub
10 changed files with 1844 additions and 37 deletions

View File

@@ -22,7 +22,7 @@
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ |
## Quickstart ⚡
@@ -33,6 +33,7 @@
git clone https://github.com/OpenAccess-AI-Collective/axolotl
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config
@@ -53,6 +54,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
```
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date)
Or run on the current files for development:
@@ -67,9 +69,19 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following:
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
- `pip3 install -e .[gptq_triton]`
- Recommended, supports QLoRA, NO gptq/int4 support
```bash
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
```
- LambdaLabs
<details>
@@ -78,7 +90,8 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
1. Install python
```bash
sudo apt install python3.9
sudo apt update
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
@@ -205,14 +218,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- custom prompts structure:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type.
</details>
#### How to add custom prompts
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type.
Optionally, download some datasets, see [data/README.md](data/README.md)
### Config
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
@@ -370,7 +387,7 @@ train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# does not work with current implementation of 4-bit LoRA
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
@@ -400,6 +417,8 @@ flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# resume from a specific checkpoint dir
resume_from_checkpoint:

View File

@@ -0,0 +1,92 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: QingyiSi/Alpaca-CoT
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
debug:
deepspeed:
weight_decay: 0.000001
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"

View File

@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -183,6 +190,9 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")

File diff suppressed because it is too large Load Diff

View File

@@ -78,6 +78,13 @@ def load_tokenized_prepared_datasets(
else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
if cfg.seed:
seed = cfg.seed
else:
logging.info("No seed provided, using default seed of 42")
seed = 42
datasets = []
# pylint: disable=invalid-name
for d in cfg.datasets:
@@ -127,11 +134,11 @@ def load_tokenized_prepared_datasets(
# support for using a subset of the data
if d.shards:
if "train" in ds:
ds = ds.shuffle(seed=42)["train"].shard(
ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0
)
else:
ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
@@ -239,7 +246,7 @@ def load_tokenized_prepared_datasets(
samples: List[int] = []
for d in datasets:
samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=42)
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
logging.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}"

View File

@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
)
try:
from transformers import LlamaForCausalLM
from transformers import ( # pylint: disable=unused-import # noqa: F401
LlamaForCausalLM,
)
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
@@ -82,37 +84,47 @@ def load_model(
cfg,
adapter="lora"
):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
is_llama_derived_model = "llama" in base_model or (
cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
logging.info("patching with xformers attention")
hijack_llama_attention()
elif is_llama_derived_model and cfg.sdp_attention:
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
MEM_TOKEN,
LlamaForCausalLM,
)
logging.info("patching with landmark attention")
# TODO: Check if this would overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.bf16:
torch_dtype = torch.bfloat16
@@ -127,11 +139,18 @@ def load_model(
)
replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
except Exception as err:
logging.exception(err)
raise err
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {}
if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -143,7 +162,7 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
try:
if cfg.gptq and is_llama_derived_model:
if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
@@ -181,7 +200,7 @@ def load_model(
else True,
)
load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
@@ -235,8 +254,15 @@ def load_model(
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and cfg.sequence_len > config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
@@ -268,8 +294,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter)

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
@@ -62,8 +63,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
save_steps = cfg.save_steps
eval_steps = cfg.eval_steps
training_arguments_kwargs = {}
if cfg.bf16 == "full":
@@ -74,6 +73,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
@@ -119,16 +122,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if save_steps else "epoch",
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps,
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0
and save_steps
and save_steps % eval_steps == 0
and cfg.save_steps
and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
@@ -233,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
batched=False,
num_proc=32,
)
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")

View File

@@ -54,6 +54,9 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else:
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_falcon_fsdp(self):
regex_exp = r".*FSDP is not supported for falcon models.*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
# Check for upper-case
cfg = DictDefault(
{
"base_model": "Falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
}
)
validate_config(cfg)