more tweaks to do pre-training with bettertransformers

This commit is contained in:
Wing Lian
2023-05-31 21:59:15 -04:00
parent ed7531abb8
commit 86bd9fcff4
6 changed files with 54 additions and 12 deletions
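The changes below wire optimum's BetterTransformer path into pre-training: the finetune script materializes the streaming pretraining dataset, a new callback un-wraps the model before checkpointing, and the config validator gains BetterTransformer-specific checks. As a quick refresher, the optimum API these changes rely on looks roughly like this (a minimal sketch, not code from this commit; the model name is illustrative):

# Minimal sketch of the optimum BetterTransformer API used throughout this commit.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model only

# Swap supported layers for fused / scaled-dot-product-attention implementations.
model = BetterTransformer.transform(model)

# ... training or inference runs on the transformed model ...

# Convert back to the canonical transformers modules before serializing;
# this is what the new save callback below automates.
model = BetterTransformer.reverse(model)
model.save_pretrained("./checkpoint")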

View File

@@ -14,6 +14,7 @@ import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig
@@ -204,6 +205,7 @@ def train(
train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
)
train_dataset = Dataset.from_list(list(train_dataset))
eval_dataset = None
if cfg.debug or "debug" in kwargs:

View File

@@ -2,6 +2,7 @@
import os
from optimum.bettertransformer import BetterTransformer
from transformers import (
TrainerCallback,
TrainerControl,
@@ -30,3 +31,26 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path)
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BatterTransformer wrapped model"""
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
return control
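A callback like this gets registered when the Trainer is built (see the setup_trainer hunk further down). A minimal sketch of the effect, with the Trainer arguments as placeholders rather than values from this commit:

# Sketch: registering the callback with a transformers Trainer.
# model, training_args, and train_dataset are placeholders.
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[SaveBetterTransformerModelCallback],
)
# On every checkpoint save, on_save() calls BetterTransformer.reverse() so that
# save_pretrained() writes weights in the standard transformers layout.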

View File

@@ -402,14 +402,16 @@ class PretrainingDatasetWrapper(IterableDataset):
buffer = []
for sample in load_dataset(
self.dataset_path,
name="all",
split="train",
streaming=True,
).shuffle(buffer_size=10000):
)["train"].shuffle():
buffer += self.tokenizer(sample["text"])["input_ids"]
buffer += [self.tokenizer.eos_token_id]
while len(buffer) > self.max_tokens:
yield torch.tensor(buffer[: self.max_tokens])
input_ids = torch.tensor(buffer[: self.max_tokens])
yield {
"input_ids": input_ids,
"attention_mask": torch.ones(input_ids.size()),
"labels": input_ids,
}
buffer = buffer[self.max_tokens :]
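The wrapper now yields ready-made feature dicts (input_ids, attention_mask, labels) of length max_tokens instead of bare tensors, which is what lets finetune.py materialize it with Dataset.from_list above. A minimal usage sketch, with the dataset path and tokenizer as placeholders (this commit does not pin either):

# Sketch: consuming the streaming wrapper, mirroring the finetune.py hunk above.
# The dataset path and tokenizer are placeholders, not values from this commit.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_dataset = load_pretraining_dataset(
    "togethercomputer/RedPajama-Data-1T", tokenizer, max_tokens=2048
)
# Each yielded item is a dict of fixed-length tensors, ready for the collator.
train_dataset = Dataset.from_list(list(train_dataset))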

View File

@@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import PreTrainedModel # noqa: F401
from optimum.bettertransformer import BetterTransformer
from transformers import PreTrainedModel # noqa: F401
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -116,7 +116,7 @@ def load_model(
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
if cfg.bf16:
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
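The dtype selection in load_model now also honors a bfloat16 alias in the config. The resulting precedence is roughly the following (the final fall-through is not visible in this hunk and is assumed):

# Sketch of the dtype precedence implied by this hunk.
if cfg.bf16 or cfg.bfloat16:
    torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32  # assumed default, not shown in this diff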

View File

@@ -15,7 +15,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -225,6 +228,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True,
}
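Nothing in this diff sets the use_bettertransformer attribute that the new check looks for; presumably it is attached wherever the model gets wrapped. A hedged sketch of what that would look like (an assumption, not code from this commit):

# Assumption: where the model is wrapped, a flag like this would be set so that
# setup_trainer knows to register SaveBetterTransformerModelCallback.
from optimum.bettertransformer import BetterTransformer

if cfg.flash_optimum:
    model = BetterTransformer.transform(model)
    model.use_bettertransformer = True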

View File

@@ -1,8 +1,10 @@
"""Module for validating config files"""
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
@@ -50,14 +52,20 @@ def validate_config(cfg):
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True:
logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
if torch.__version__.split(".")[0] < 2:
logging.warning(
"You should probably set float16 to true to load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
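One note on the version check above: the old line compared torch.__version__.split(".")[0], a string, against the integer 2, which raises a TypeError in Python 3, so the fix casts the major version to int. An alternative sketch (not what this commit does) that also copes with local version suffixes such as "2.0.0+cu118":

# Alternative sketch using packaging instead of splitting the version string.
import torch
from packaging import version

if version.parse(torch.__version__) < version.parse("2.0.0"):
    raise ValueError(
        f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
    )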