diff --git a/README.md b/README.md
index 5a00cccac..e267a9d6d 100644
--- a/README.md
+++ b/README.md
@@ -421,6 +421,8 @@ optimizer:
 # specify weight decay
 weight_decay:
 
+# whether to use BetterTransformers (via optimum)
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
diff --git a/examples/pythia-12b/README.md b/examples/pythia-12b/README.md
new file mode 100644
index 000000000..123ffa710
--- /dev/null
+++ b/examples/pythia-12b/README.md
@@ -0,0 +1,9 @@
+# Pythia 12B
+
+- Single-GPU A100 only (?)
+
+```shell
+python scripts/finetune.py examples/pythia-12b/config.yml
+```
+
+⚠️ Multi-GPU A100 doesn't seem to work yet - training runs out of memory (OOM)! ⚠️
diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml
new file mode 100644
index 000000000..3b3d91630
--- /dev/null
+++ b/examples/pythia-12b/config.yml
@@ -0,0 +1,49 @@
+base_model: EleutherAI/pythia-12b-deduped
+base_model_config: EleutherAI/pythia-12b-deduped
+base_model_ignore_patterns: pytorch*  # prefer safetensors
+model_type: GPTNeoXForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+load_in_4bit: false
+gptq: false
+device_map: auto
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 64
+lora_alpha: 32
+lora_dropout: 0.0
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./pythia-12b
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 5
+learning_rate: 0.00003
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+gradient_checkpointing: true
+fsdp:
+fsdp_config:
+collator_pad_to_longest: true
diff --git a/requirements.txt b/requirements.txt
index c9123fce8..d1b2f4555 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
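The `flash_optimum: true` path added here hinges on optimum's BetterTransformer wrap/unwrap cycle. A rough, illustrative sketch of that cycle (not part of the patch; assumes `torch>=2.0` with `optimum` installed, and uses `EleutherAI/pythia-70m` only as a small stand-in checkpoint):

```python
import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

# Load in half precision (no AMP), then swap attention for the
# BetterTransformer fast path.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m", torch_dtype=torch.float16
)
model = BetterTransformer.transform(model)

# ... train or run inference with the wrapped model ...

# The wrapped model can't be saved directly; reverse the transform first,
# which is what the new save/terminate paths in this patch do.
model = BetterTransformer.reverse(model)
model.save_pretrained("./pythia-70m-out")
```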
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 785f3cf23..a1c5b13b9 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig, TextStreamer
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 # add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig, TextStreamer
+
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            train_dataset = load_pretraining_dataset(
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
+            )
+            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+            train_dataset = train_dataset.with_format("torch")
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
 
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)
 
     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
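The `torch.backends.cuda.sdp_kernel(...)` context wrapped around `trainer.train(...)` above only controls which scaled-dot-product-attention kernels PyTorch may dispatch to. A small standalone check of that mechanism, illustrative only and not from the patch, assuming the torch 2.x `sdp_kernel` API and a CUDA device:

```python
import torch
import torch.nn.functional as F

# Dummy (batch, heads, seq, head_dim) tensors in fp16 on the GPU.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Same kernel flags the training loop uses: allow flash, math, and
# memory-efficient implementations and let PyTorch pick one.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=True, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```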
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index f6852249a..526121f2e 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -2,13 +2,14 @@
 
 import os
 
+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
         kwargs["model"].save_pretrained(peft_model_path)
 
         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # FIXME - need to cleanup old checkpoints
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
+        return control
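The new callback takes over step-based checkpointing because the stock Trainer save path can't serialize a BetterTransformer-wrapped model: it reverses the transform, writes the checkpoint itself, and clears `control.should_save` so the trainer loop doesn't try again. It reuses the standard checkpoint folder naming, as this small runnable illustration shows (the `output_dir` and `save_steps` values are made up for the example):

```python
import os

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR  # == "checkpoint"

output_dir = "./pythia-12b"  # example value matching the config above
save_steps = 50              # hypothetical TrainingArguments.save_steps

for global_step in (49, 50, 100, 101):
    # Same gating the callback applies for a steps-based save strategy.
    should_save = save_steps > 0 and global_step % save_steps == 0
    if should_save:
        print(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}"))
# -> ./pythia-12b/checkpoint-50
# -> ./pythia-12b/checkpoint-100
```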
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 9fee2fb9b..c36bfcee9 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,10 +1,11 @@
 """Module containing data utilities"""
-
+import functools
 import logging
 from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
+import torch
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -394,8 +395,127 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )
 
-    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["test"]
+    if cfg.val_set_size:
+        dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
+    else:
+        train_dataset = dataset
+        eval_dataset = None
 
     return train_dataset, eval_dataset
+
+
+def encode_pretraining(tokenizer, max_tokens, examples):
+    res = tokenizer(
+        examples["text"],
+        truncation=True,
+        max_length=max_tokens - 2,
+        add_special_tokens=True,
+    )
+    # Convert to PyTorch tensors
+    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    new_input_ids = []
+    new_attention_mask = []
+    # Append EOS and PAD tokens to input_ids, and correct attention_mask
+    for i, _ in enumerate(input_ids):
+        input_ids[i] = torch.cat(
+            (
+                input_ids[i],
+                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+            ),
+            dim=0,
+        )
+        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
+
+    # Concatenate tokens so that their lengths are less than max_tokens
+    buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+    for ids, mask in zip(input_ids, attention_mask):
+        if buffer_input_ids.numel() == max_tokens:
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        else:
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+        new_input_ids.append(buffer_input_ids)
+        new_attention_mask.append(buffer_attention_mask)
+
+    ret = {
+        "input_ids": [seq.tolist() for seq in new_input_ids],
+        "labels": [seq.tolist() for seq in new_input_ids],
+        "attention_mask": [seq.tolist() for seq in new_attention_mask],
+    }
+
+    logging.debug(len(ret["input_ids"]))
+    return ret
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
+    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+    dataset = load_dataset(path, streaming=True, split="train")
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    # TODO dynamically figure out which columns/features to remove
+    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    return dataset
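`encode_pretraining` packs a batch of tokenized texts into fixed-length `max_tokens` rows: each document gets EOS + PAD appended, documents are concatenated while they fit, and the leftover buffer is padded out. A toy run with this patch applied, assuming a `transformers` install and using `EleutherAI/pythia-70m` purely as an example tokenizer:

```python
from transformers import AutoTokenizer

from axolotl.utils.data import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
if tokenizer.pad_token is None:
    # GPT-NeoX tokenizers ship without a pad token; reuse EOS for this demo
    tokenizer.pad_token = tokenizer.eos_token

examples = {"text": ["first tiny document", "second tiny document"]}
batch = encode_pretraining(tokenizer, 32, examples)

# Both documents (each followed by EOS + PAD) fit into one packed row of
# exactly 32 tokens; labels mirror input_ids and the attention mask is 0
# over the padding.
print(len(batch["input_ids"]), len(batch["input_ids"][0]))  # -> 1 32
```

`load_pretraining_dataset` then maps this function in batches over a streamed dataset, which is why the hard-coded `remove_columns=["text", "meta"]` (flagged by the TODO) only fits pile-style corpora for now.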
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d9bff4b14..05acfce93 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
+from optimum.bettertransformer import BetterTransformer
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import (  # noqa: F401
+from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -121,9 +122,9 @@ def load_model(
         logging.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
 
-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -287,6 +288,15 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)
 
+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and cfg.sequence_len >= model.config.max_position_embeddings
+    ):
+        logging.warning(
+            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+        )
+        model.config.max_position_embeddings = cfg.sequence_len
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +342,9 @@ def load_model(
         logging.warning("there are no parameters that require gradient updates")
     model.config.use_cache = False
 
+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 1250ad4f6..5152e649b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
-from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.callbacks import (
+    SaveBetterTransformerModelCallback,
+    SavePeftModelCallback,
+)
 from axolotl.utils.schedulers import InterpolatingLogScheduler
 
@@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        callbacks.append(SaveBetterTransformerModelCallback)
+
     data_collator_kwargs = {
         "padding": True,
     }
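`setup_trainer` keys off a `use_bettertransformer` attribute to decide whether to register the new callback; the patch assumes `BetterTransformer.transform` leaves that marker on the returned model. A quick, illustrative way to check that assumption against your installed `optimum` version (again, `pythia-70m` is only a small example checkpoint):

```python
import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m", torch_dtype=torch.float16
)
model = BetterTransformer.transform(model)

# If this prints False on your optimum version, the callback would never be
# registered and checkpoint saving on the wrapped model would fail.
print(getattr(model, "use_bettertransformer", False))
```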
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index e2d0b34b1..298d36c4e 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -2,6 +2,8 @@
 
 import logging
 
+import torch
+
 
 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,37 @@ def validate_config(cfg):
     ) and cfg.gradient_checkpointing:
         raise ValueError("gradient_checkpointing is not supported for MPT models")
 
+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True and cfg.bfloat16 is not True:
+            logging.warning(
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in half precision for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
+            logging.warning("torch>=2.0.0 required")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )
+
+    if cfg.pretraining_dataset and cfg.group_by_length:
+        logging.warning(
+            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
-    # no 8bit adamw w bf16
+    # no 8bit adamw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    #   attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
diff --git a/tests/test_validation.py b/tests/test_validation.py
index e28891060..dba54586e 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -212,3 +212,54 @@ class ValidationTest(unittest.TestCase):
 
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_flash_optimum(self):
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "adapter": "lora",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "probably set bfloat16 or float16" in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "fp16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "bf16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)