Compare commits

...

8 Commits

Author SHA1 Message Date
Wing Lian
6fcb73faaa more gpt-neox long ctx fixes
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-06-01 08:20:08 -04:00
Wing Lian
a32cc1d021 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-01 00:33:13 -04:00
Wing Lian
86bd9fcff4 more tweaks to do pre-training with bettertransformers 2023-05-31 21:59:15 -04:00
Wing Lian
ed7531abb8 experimental expansion of ctx len 2023-05-31 16:51:19 -04:00
Wing Lian
bdb547b830 add validation/warning for bettertransformers and torch version 2023-05-31 16:41:24 -04:00
Wing Lian
8a37b43678 use pythia-12b, neox-20b is flaky 2023-05-31 16:41:21 -04:00
Wing Lian
28acebac36 add flash attn context for efficient training and attempt setting model to train mode: 2023-05-31 16:40:38 -04:00
Wing Lian
adea682316 add support for opimum bettertransformers 2023-05-31 16:39:35 -04:00
9 changed files with 222 additions and 43 deletions

View File

@@ -0,0 +1,10 @@
# Python 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -1,39 +1,49 @@
base_model: EleutherAI/gpt-neox-20b
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: nomic-ai/gpt4all-j-prompt-generations
- path: vicgalle/alpaca-gpt4
type: alpaca
shards: 4
shards_index: 0
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_dropout: 0.0
lora_target_modules:
- query_key_value
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: gpt4all-neox-20b
wandb_project: pythia-12b
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./gpt4all-neox-20b
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 4
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
lr_scheduler: one_cycle
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_transformer_layer_cls_to_wrap:
collator_pad_to_longest: true

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb
einops
xformers
optimum
# qlora things
bert-score==0.3.13
evaluate==0.4.0

View File

@@ -12,13 +12,15 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
@@ -47,10 +49,11 @@ def choose_device(cfg):
return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
if cfg.device_map != "auto":
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
@@ -190,9 +193,20 @@ def train(
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
if cfg.pretraining_dataset is True:
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
else:
pretraining_dataset = cfg.pretraining_dataset
train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
)
train_dataset = Dataset.from_list(list(train_dataset))
eval_dataset = None
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
@@ -238,6 +252,21 @@ def train(
model.save_pretrained(cfg.output_dir)
return
if cfg.debug:
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
),
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
model.train()
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False
@@ -253,12 +282,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -278,13 +310,22 @@ def train(
logging.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -2,13 +2,14 @@
import os
from optimum.bettertransformer import BetterTransformer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path)
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control

View File

@@ -5,7 +5,8 @@ from hashlib import md5
from pathlib import Path
from typing import List, Tuple, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import torch
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -380,8 +381,43 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
if cfg.val_set_size:
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset
class PretrainingDatasetWrapper(IterableDataset):
"""
Wrapper for pretraining dataset that avoids loading the dataset into memory
"""
def __init__(self, tokenizer, dataset_path, max_tokens=2048):
self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.max_tokens = max_tokens
def __iter__(self):
buffer = []
for sample in load_dataset(
self.dataset_path,
)["train"].shuffle():
buffer += self.tokenizer(sample["text"])["input_ids"]
buffer += [self.tokenizer.eos_token_id]
while len(buffer) > self.max_tokens:
input_ids = torch.tensor(buffer[: self.max_tokens])
yield {
"input_ids": input_ids,
"attention_mask": torch.ones(input_ids.size()),
"labels": input_ids,
}
buffer = buffer[self.max_tokens :]
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)

View File

@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from transformers import PreTrainedModel # noqa: F401
from transformers import ( # noqa: F401
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -115,9 +116,9 @@ def load_model(
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
if cfg.bf16:
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
@@ -261,6 +262,12 @@ def load_model(
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)
if cfg.sequence_len >= model.config.max_position_embeddings:
logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -304,6 +311,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling
return model, lora_config

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
@@ -15,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -225,6 +229,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
logging.info("Setting up SaveBetterTransformerModelCallback.")
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True,
}

View File

@@ -2,6 +2,8 @@
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -48,7 +50,31 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
logging.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adamw w bf16
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3