From 86bd9fcff4deb724058ded70111ecca46dc6487c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 21:59:15 -0400 Subject: [PATCH] more tweaks to do pre-training with bettertransformers --- scripts/finetune.py | 2 ++ src/axolotl/utils/callbacks.py | 24 ++++++++++++++++++++++++ src/axolotl/utils/data.py | 12 +++++++----- src/axolotl/utils/models.py | 4 ++-- src/axolotl/utils/trainer.py | 8 +++++++- src/axolotl/utils/validation.py | 16 ++++++++++++---- 6 files changed, 54 insertions(+), 12 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 21d0044be..b17631894 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -14,6 +14,7 @@ import torch import yaml # add src to the pythonpath so we don't need to pip install this +from datasets import Dataset from optimum.bettertransformer import BetterTransformer from transformers import GenerationConfig @@ -204,6 +205,7 @@ def train( train_dataset = load_pretraining_dataset( pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len ) + train_dataset = Dataset.from_list(list(train_dataset)) eval_dataset = None if cfg.debug or "debug" in kwargs: diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index f6852249a..ab197304c 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -2,6 +2,7 @@ import os +from optimum.bettertransformer import BetterTransformer from transformers import ( TrainerCallback, TrainerControl, @@ -30,3 +31,26 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public- kwargs["model"].save_pretrained(peft_model_path) return control + + +class SaveBetterTransformerModelCallback( + TrainerCallback +): # pylint: disable=too-few-public-methods + """Callback to save the BetterTransformer wrapped model""" + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + checkpoint_folder = os.path.join( + args.output_dir, 
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", + ) + + model = BetterTransformer.reverse(kwargs["model"]) + model.save_pretrained(checkpoint_folder) + + return control diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 4566c09fb..2f5bc66fd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -402,14 +402,16 @@ class PretrainingDatasetWrapper(IterableDataset): buffer = [] for sample in load_dataset( self.dataset_path, - name="all", - split="train", - streaming=True, - ).shuffle(buffer_size=10000): + )["train"].shuffle(): buffer += self.tokenizer(sample["text"])["input_ids"] buffer += [self.tokenizer.eos_token_id] while len(buffer) > self.max_tokens: - yield torch.tensor(buffer[: self.max_tokens]) + input_ids = torch.tensor(buffer[: self.max_tokens]) + yield { + "input_ids": input_ids, + "attention_mask": torch.ones(input_ids.size()), + "labels": input_ids, + } buffer = buffer[self.max_tokens :] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b3d46ce1c..7d8a17ed2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import PreTrainedModel # noqa: F401 from optimum.bettertransformer import BetterTransformer +from transformers import PreTrainedModel # noqa: F401 from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -116,7 +116,7 @@ def load_model( logging.info("patching with sdp attention") hijack_llama_sdp_attention() - if cfg.bf16: + if cfg.bf16 or cfg.bfloat16: torch_dtype = torch.bfloat16 elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: torch_dtype = torch.float16 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2986c491b..a530f34fc 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -15,7 +15,10 @@ from torch.optim.lr_scheduler import OneCycleLR 
from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names -from axolotl.utils.callbacks import SavePeftModelCallback +from axolotl.utils.callbacks import ( + SaveBetterTransformerModelCallback, + SavePeftModelCallback, +) from axolotl.utils.schedulers import InterpolatingLogScheduler @@ -225,6 +228,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): ]: # only save in rank 0 callbacks.append(SavePeftModelCallback) + if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: + callbacks.append(SaveBetterTransformerModelCallback) + data_collator_kwargs = { "padding": True, } diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 2b4fb47a4..11e9afe45 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,8 +1,10 @@ """Module for validating config files""" import logging + import torch + def validate_config(cfg): if cfg.gradient_accumulation_steps and cfg.batch_size: raise ValueError( @@ -50,14 +52,20 @@ def validate_config(cfg): if cfg.flash_optimum is True: if cfg.adapter: - logging.warning("BetterTransformers probably doesn't work with PEFT adapters") + logging.warning( + "BetterTransformers probably doesn't work with PEFT adapters" + ) if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") if cfg.float16 is not True: - logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers") - if torch.__version__.split(".")[0] < 2: + logging.warning( + "You should probably set float16 to true to load the model in float16 for BetterTransformers" + ) + if int(torch.__version__.split(".")[0]) < 2: logging.warning("torch>=2.0.0 required") - raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}") + raise ValueError( + f"flash_optimum for BetterTransformers may not be used with 
{torch.__version__}" + ) # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25