diff --git a/scripts/finetune.py b/scripts/finetune.py
index 88815dfdd..9bed61ca4 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -14,6 +14,7 @@ import torch
 import yaml

 # add src to the pythonpath so we don't need to pip install this
+from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer

@@ -214,6 +215,7 @@ def train(
         train_dataset = load_pretraining_dataset(
             pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
         )
+        train_dataset = Dataset.from_list(list(train_dataset))
         eval_dataset = None

     if cfg.debug or "debug" in kwargs:
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index f6852249a..ab197304c 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -2,6 +2,7 @@

 import os

+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -30,3 +31,26 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
             kwargs["model"].save_pretrained(peft_model_path)

         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_save(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        checkpoint_folder = os.path.join(
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+        )
+
+        model = BetterTransformer.reverse(kwargs["model"])
+        model.save_pretrained(checkpoint_folder)
+
+        return control
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 49314372a..164296ee2 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -409,14 +409,16 @@ class PretrainingDatasetWrapper(IterableDataset):
         buffer = []
         for sample in load_dataset(
             self.dataset_path,
-            name="all",
-            split="train",
-            streaming=True,
-        ).shuffle(buffer_size=10000):
+        )["train"].shuffle():
             buffer += self.tokenizer(sample["text"])["input_ids"]
             buffer += [self.tokenizer.eos_token_id]
             while len(buffer) > self.max_tokens:
-                yield torch.tensor(buffer[: self.max_tokens])
+                input_ids = torch.tensor(buffer[: self.max_tokens])
+                yield {
+                    "input_ids": input_ids,
+                    "attention_mask": torch.ones(input_ids.size()),
+                    "labels": input_ids,
+                }
                 buffer = buffer[self.max_tokens :]


diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 11b4629ec..91ef96ca9 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import PreTrainedModel  # noqa: F401
 from optimum.bettertransformer import BetterTransformer
+from transformers import PreTrainedModel  # noqa: F401
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -136,7 +136,7 @@ def load_model(
         logging.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9ae1e7e93..b7823fea4 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names

-from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.callbacks import (
+    SaveBetterTransformerModelCallback,
+    SavePeftModelCallback,
+)
 from axolotl.utils.schedulers import InterpolatingLogScheduler


@@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)

+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        callbacks.append(SaveBetterTransformerModelCallback)
+
     data_collator_kwargs = {
         "padding": True,
     }
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index db19900cc..abaaba8d0 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -1,8 +1,10 @@
 """Module for validating config files"""

 import logging
+
 import torch
+

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
         raise ValueError(
@@ -59,14 +61,20 @@ def validate_config(cfg):

     if cfg.flash_optimum is True:
         if cfg.adapter:
-            logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
         if cfg.fp16 or cfg.bf16:
             raise ValueError("AMP is not supported with BetterTransformer")
         if cfg.float16 is not True:
-            logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
-        if torch.__version__.split(".")[0] < 2:
+            logging.warning(
+                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
             logging.warning("torch>=2.0.0 required")
-            raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )

     # TODO
     # MPT 7b
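
Note (not part of the patch): a minimal, self-contained sketch of the feature dicts the reworked PretrainingDatasetWrapper now yields and of the Dataset.from_list(list(...)) materialization added in scripts/finetune.py. The token ids, eos id, and max_tokens value below are made-up stand-ins for the real tokenizer output.

# Sketch only: toy token ids stand in for tokenizer(sample["text"])["input_ids"].
import torch
from datasets import Dataset

max_tokens = 4
eos_token_id = 0  # placeholder id, not the project's tokenizer
tokenized_texts = [[11, 12, 13, 14, 15], [21, 22, 23]]

def stream_samples():
    buffer = []
    for ids in tokenized_texts:
        buffer += ids
        buffer += [eos_token_id]
        while len(buffer) > max_tokens:
            input_ids = torch.tensor(buffer[:max_tokens])
            # each element is now a full feature dict rather than a bare tensor,
            # so the collator receives input_ids, attention_mask and labels
            yield {
                "input_ids": input_ids,
                "attention_mask": torch.ones(input_ids.size()),
                "labels": input_ids,
            }
            buffer = buffer[max_tokens:]

# finetune.py does the equivalent of this to turn the iterable into an
# indexable, map-style Dataset the Trainer can shuffle and batch
train_dataset = Dataset.from_list(list(stream_samples()))
print(train_dataset.column_names)  # ['input_ids', 'attention_mask', 'labels']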
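
Note (not part of the patch): a minimal sketch of the save flow the new SaveBetterTransformerModelCallback depends on, namely reversing the BetterTransformer wrapping before save_pretrained so the checkpoint on disk is a plain transformers model. The model name and output directory below are placeholders; the only APIs assumed are optimum's BetterTransformer.transform/reverse and transformers' save_pretrained.

# Sketch only: BetterTransformer-wrapped weights must be reversed back to a
# standard transformers model before saving, which is what the callback does
# at each checkpoint.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

# wrap the model to use the fused BetterTransformer attention kernels
model = BetterTransformer.transform(model)

# ... training happens here ...

# undo the wrapping so the saved checkpoint is a vanilla transformers model
model = BetterTransformer.reverse(model)
model.save_pretrained("checkpoint-demo")  # placeholder output dir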