more tweaks to do pre-training with bettertransformers

This commit is contained in:
Wing Lian
2023-05-31 21:59:15 -04:00
parent ed7531abb8
commit 86bd9fcff4
6 changed files with 54 additions and 12 deletions

View File

@@ -14,6 +14,7 @@ import torch
import yaml import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig from transformers import GenerationConfig
@@ -204,6 +205,7 @@ def train(
train_dataset = load_pretraining_dataset( train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
) )
train_dataset = Dataset.from_list(list(train_dataset))
eval_dataset = None eval_dataset = None
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:

View File

@@ -2,6 +2,7 @@
import os import os
from optimum.bettertransformer import BetterTransformer
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
@@ -30,3 +31,26 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path) kwargs["model"].save_pretrained(peft_model_path)
return control return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BatterTransformer wrapped model"""
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
return control

View File

@@ -402,14 +402,16 @@ class PretrainingDatasetWrapper(IterableDataset):
buffer = [] buffer = []
for sample in load_dataset( for sample in load_dataset(
self.dataset_path, self.dataset_path,
name="all", )["train"].shuffle():
split="train",
streaming=True,
).shuffle(buffer_size=10000):
buffer += self.tokenizer(sample["text"])["input_ids"] buffer += self.tokenizer(sample["text"])["input_ids"]
buffer += [self.tokenizer.eos_token_id] buffer += [self.tokenizer.eos_token_id]
while len(buffer) > self.max_tokens: while len(buffer) > self.max_tokens:
yield torch.tensor(buffer[: self.max_tokens]) input_ids = torch.tensor(buffer[: self.max_tokens])
yield {
"input_ids": input_ids,
"attention_mask": torch.ones(input_ids.size()),
"labels": input_ids,
}
buffer = buffer[self.max_tokens :] buffer = buffer[self.max_tokens :]

View File

@@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from transformers import PreTrainedModel # noqa: F401
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import PreTrainedModel # noqa: F401
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -116,7 +116,7 @@ def load_model(
logging.info("patching with sdp attention") logging.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
if cfg.bf16: if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16 torch_dtype = torch.float16

View File

@@ -15,7 +15,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -225,6 +228,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0 ]: # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True,
} }

View File

@@ -1,8 +1,10 @@
"""Module for validating config files""" """Module for validating config files"""
import logging import logging
import torch import torch
def validate_config(cfg): def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size: if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError( raise ValueError(
@@ -50,14 +52,20 @@ def validate_config(cfg):
if cfg.flash_optimum is True: if cfg.flash_optimum is True:
if cfg.adapter: if cfg.adapter:
logging.warning("BetterTransformers probably doesn't work with PEFT adapters") logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16: if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer") raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True: if cfg.float16 is not True:
logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers") logging.warning(
if torch.__version__.split(".")[0] < 2: "You should probably set float16 to true to load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required") logging.warning("torch>=2.0.0 required")
raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}") raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25