Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum

add support for opimum bettertransformers
This commit is contained in:
Wing Lian
2023-06-14 07:50:15 -04:00
committed by GitHub
11 changed files with 367 additions and 24 deletions

View File

@@ -421,6 +421,8 @@ optimizer:
# specify weight decay # specify weight decay
weight_decay: weight_decay:
# whether to bettertransformers
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers: # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention: # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:

View File

@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -0,0 +1,49 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb wandb
einops einops
xformers xformers
optimum
# qlora things # qlora things
bert-score==0.3.13 bert-score==0.3.13
evaluate==0.4.0 evaluate==0.4.0

View File

@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
if ( if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these ): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( if not cfg.pretraining_dataset:
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH train_dataset, eval_dataset = load_prepare_datasets(
) tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...") logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
) )
logging.info("Starting trainer...") logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0: if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -2,13 +2,14 @@
import os import os
from optimum.bettertransformer import BetterTransformer
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
TrainerState, TrainerState,
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path) kwargs["model"].save_pretrained(peft_model_path)
return control return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control

View File

@@ -1,10 +1,11 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -394,8 +395,127 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx, index=cfg.dataset_shard_idx,
) )
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) if cfg.val_set_size:
train_dataset = dataset["train"] dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
eval_dataset = dataset["test"] train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
res = tokenizer(
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
logging.debug(len(ret["input_ids"]))
return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
return dataset

View File

@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from optimum.bettertransformer import BetterTransformer
from transformers import PreTrainedModel # noqa: F401 from transformers import PreTrainedModel # noqa: F401
from transformers import ( # noqa: F401 from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@@ -121,9 +122,9 @@ def load_model(
logging.info("patching with xpos rope") logging.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope() replace_llama_rope_with_xpos_rope()
if cfg.bf16: if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16: elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16 torch_dtype = torch.float16
else: else:
torch_dtype = torch.float32 torch_dtype = torch.float32
@@ -287,6 +288,15 @@ def load_model(
embeddings_len = math.ceil(len(tokenizer) / 32) * 32 embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and cfg.sequence_len >= model.config.max_position_embeddings
):
logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +342,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates") logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config

View File

@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0 ]: # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True,
} }

View File

@@ -2,6 +2,8 @@
import logging import logging
import torch
def validate_config(cfg): def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size: if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,37 @@ def validate_config(cfg):
) and cfg.gradient_checkpointing: ) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models") raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
logging.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
logging.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adamw w bf16 # no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -212,3 +212,54 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=regex_exp): with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg) validate_config(cfg)
def test_flash_optimum(self):
cfg = DictDefault(
{
"flash_optimum": True,
"adapter": "lora",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
"fp16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"flash_optimum": True,
"bf16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)