Compare commits

...

8 Commits

Author SHA1 Message Date
Wing Lian
6fcb73faaa more gpt-neox long ctx fixes 2023-06-01 08:20:08 -04:00
Wing Lian
a32cc1d021 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-01 00:33:13 -04:00
Wing Lian
86bd9fcff4 more tweaks to do pre-training with bettertransformers 2023-05-31 21:59:15 -04:00
Wing Lian
ed7531abb8 experimental expansion of ctx len 2023-05-31 16:51:19 -04:00
Wing Lian
bdb547b830 add validation/warning for bettertransformers and torch version 2023-05-31 16:41:24 -04:00
Wing Lian
8a37b43678 use pythia-12b, neox-20b is flaky 2023-05-31 16:41:21 -04:00
Wing Lian
28acebac36 add flash attn context for efficient training and attempt setting model to train mode: 2023-05-31 16:40:38 -04:00
Wing Lian
adea682316 add support for opimum bettertransformers 2023-05-31 16:39:35 -04:00
9 changed files with 222 additions and 43 deletions

View File

@@ -0,0 +1,10 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -1,39 +1,49 @@
-base_model: EleutherAI/gpt-neox-20b
+base_model: EleutherAI/pythia-12b-deduped
+base_model_config: EleutherAI/pythia-12b-deduped
 base_model_ignore_patterns: pytorch*  # prefer safetensors
 model_type: GPTNeoXForCausalLM
 tokenizer_type: AutoTokenizer
-load_in_8bit: true
+load_in_8bit: false
+load_in_4bit: false
+gptq: false
+device_map: auto
 datasets:
-  - path: nomic-ai/gpt4all-j-prompt-generations
+  - path: vicgalle/alpaca-gpt4
     type: alpaca
+shards: 4
+shards_index: 0
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-adapter: lora
+adapter:
 lora_model_dir:
 sequence_len: 2048
 max_packed_sequence_len: 2048
-lora_r: 8
+lora_r: 64
 lora_alpha: 32
-lora_dropout: 0.05
+lora_dropout: 0.0
 lora_target_modules:
-  - query_key_value
+lora_target_linear: true
 lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
-wandb_project: gpt4all-neox-20b
+wandb_project: pythia-12b
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
-output_dir: ./gpt4all-neox-20b
+output_dir: ./pythia-12b
 gradient_accumulation_steps: 1
-micro_batch_size: 4
+micro_batch_size: 1
 num_epochs: 5
 learning_rate: 0.00003
-lr_scheduler: one_cycle
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
 train_on_inputs: false
 group_by_length: false
-bf16: True
-tf32: True
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
+gradient_checkpointing: true
+fsdp:
+fsdp_transformer_layer_cls_to_wrap:
+collator_pad_to_longest: true
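
The new keys here drive the Optimum integration added in this branch: `flash_optimum: true` wraps the model with BetterTransformer, and `float16: true` loads weights in half precision since AMP (`fp16`/`bf16`) is rejected by the new validation below. A minimal sketch of that lifecycle outside axolotl (model name and dtype taken from this config; the actual call path lives in `load_model` and the save hooks further down):

```python
import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

# load weights directly in float16 (matches `float16: true` above)
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-12b-deduped", torch_dtype=torch.float16
)

# wrap the model so attention runs through torch 2.x SDPA kernels
model = BetterTransformer.transform(model)

# ... training loop ...

# un-wrap before saving; save_pretrained raises on a still-wrapped model
model = BetterTransformer.reverse(model)
model.save_pretrained("./pythia-12b")
```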

View File

@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0

View File

@@ -12,13 +12,15 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 # add src to the pythonpath so we don't need to pip install this
+from datasets import Dataset
+from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig
+
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -47,6 +49,7 @@ def choose_device(cfg):
         return "cpu"

     cfg.device = get_device()
+    if cfg.device_map != "auto":
         if cfg.device == "cuda":
             cfg.device_map = {"": cfg.local_rank}
         else:
@@ -190,9 +193,20 @@ def train(
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
     ):  # don't need to load dataset for these
+        if not cfg.pretraining_dataset:
             train_dataset, eval_dataset = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
+        else:
+            if cfg.pretraining_dataset is True:
+                pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
+            else:
+                pretraining_dataset = cfg.pretraining_dataset
+            train_dataset = load_pretraining_dataset(
+                pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            )
+            train_dataset = Dataset.from_list(list(train_dataset))
+            eval_dataset = None

     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -238,6 +252,21 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return

+    if cfg.debug:
+        logging.info("check_dataset_labels...")
+        check_dataset_labels(
+            train_dataset.select(
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]  # nosec
+            ),
+            tokenizer,
+        )
+
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
+    model.train()
+
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
     model.config.use_cache = False
@@ -253,12 +282,15 @@ def train(
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )

     logging.info("Starting trainer...")
@@ -278,6 +310,13 @@ def train(
         logging.info(
             f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
         )
+
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)

     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
@@ -285,6 +324,8 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)

     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
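
For context, `torch.backends.cuda.sdp_kernel` in the hunk above is torch 2.x's context manager for selecting scaled-dot-product attention backends. A standalone sketch (shapes arbitrary, CUDA and float16 assumed):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# same backend flags as the trainer.train() wrapper above
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=True, enable_mem_efficient=True
):
    # may dispatch to the FlashAttention kernel when shapes/dtypes allow
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```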

View File

@@ -2,13 +2,14 @@
 import os

+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
         kwargs["model"].save_pretrained(peft_model_path)

         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # FIXME - need to cleanup old checkpoints
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
+
+        return control
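
A small sketch of the checkpoint path the callback writes to; `PREFIX_CHECKPOINT_DIR` is the same `"checkpoint"` constant the Trainer itself uses, so these saves land where auto-resume expects them (output dir and step value illustrative):

```python
import os
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

output_dir, global_step = "./pythia-12b", 500
checkpoint_folder = os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
print(checkpoint_folder)  # ./pythia-12b/checkpoint-500
```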

View File

@@ -5,7 +5,8 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union

-from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+import torch
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -380,8 +381,43 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )

+    if cfg.val_set_size:
         dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    else:
+        train_dataset = dataset
+        eval_dataset = None

     return train_dataset, eval_dataset
+
+
+class PretrainingDatasetWrapper(IterableDataset):
+    """
+    Wrapper for pretraining dataset that avoids loading the dataset into memory
+    """
+
+    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
+        self.tokenizer = tokenizer
+        self.dataset_path = dataset_path
+        self.max_tokens = max_tokens
+
+    def __iter__(self):
+        buffer = []
+        for sample in load_dataset(
+            self.dataset_path,
+        )["train"].shuffle():
+            buffer += self.tokenizer(sample["text"])["input_ids"]
+            buffer += [self.tokenizer.eos_token_id]
+
+            while len(buffer) > self.max_tokens:
+                input_ids = torch.tensor(buffer[: self.max_tokens])
+                yield {
+                    "input_ids": input_ids,
+                    "attention_mask": torch.ones(input_ids.size()),
+                    "labels": input_ids,
+                }
+                buffer = buffer[self.max_tokens :]
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
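
A usage sketch for the wrapper above; the tokenizer and dataset path are illustrative, and any Hub dataset with a `train` split and a `text` column should behave the same:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")
stream = PretrainingDatasetWrapper(
    tokenizer, "togethercomputer/RedPajama-Data-1T", max_tokens=2048
)

# tokens from consecutive samples are packed into fixed-length blocks;
# leftovers stay buffered until later samples fill them out
for i, sample in enumerate(stream):
    print(sample["input_ids"].shape)  # torch.Size([2048])
    if i == 2:
        break
```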

View File

@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
+from optimum.bettertransformer import BetterTransformer
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import (  # noqa: F401
+from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -115,9 +116,9 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()

-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -261,6 +262,12 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)

+    if cfg.sequence_len >= model.config.max_position_embeddings:
+        logging.warning(
+            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+        )
+        model.config.max_position_embeddings = cfg.sequence_len
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -304,6 +311,9 @@ def load_model(
         logging.warning("there are no parameters that require gradient updates")

     model.config.use_cache = False
+
+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)

     # TODO resume_from_checkpoint handling
     return model, lora_config
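
The `max_position_embeddings` bump is the "experimental expansion of ctx len" change from the commit list. It only updates config bookkeeping; learned position state and cached masks are not resized, which is likely why evals break at longer contexts (see the traceback quoted at the end of this compare). A sketch of the effect (values illustrative):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("EleutherAI/pythia-12b-deduped")
sequence_len = 8192  # requested training context, e.g. sequence_len in config.yml

if sequence_len >= config.max_position_embeddings:
    # bookkeeping only: downstream length checks stop truncating, but
    # kernels with buffers sized to the old window can still mismatch
    config.max_position_embeddings = sequence_len
```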

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging
import math import math
import os import os
import sys import sys
@@ -15,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -225,6 +229,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0 ]: # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
logging.info("Setting up SaveBetterTransformerModelCallback.")
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True,
} }

View File

@@ -2,6 +2,8 @@
 import logging

+import torch
+

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -48,7 +50,31 @@ def validate_config(cfg):
             "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
         )

+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True and cfg.bloat16 is not True:
+            logging.warning(
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
+            logging.warning("torch>=2.0.0 required")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
-    # no 8bit adamw w bf16
+    # no 8bit adaAmw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward
+    #     attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    #     attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
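
That final `RuntimeError` is a plain broadcasting failure: Optimum's GPT-2-style SDPA wrapper builds its causal mask from the original 2048-token window, while the runtime attention mask covers the extended 8132-token context, so `causal_mask + attention_mask` cannot broadcast. A minimal reproduction with the shapes from the traceback (the exact mask layouts inside Optimum may differ):

```python
import torch

causal_mask = torch.zeros(1, 1, 2048, 2048)     # sized to the trained context window
attention_mask = torch.zeros(1, 1, 2048, 8132)  # sized to the extended context

try:
    _ = causal_mask + attention_mask
except RuntimeError as err:
    # The size of tensor a (2048) must match the size of tensor b (8132)
    # at non-singleton dimension 3
    print(err)
```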