diff --git a/scripts/finetune.py b/scripts/finetune.py
index 99236b087..88815dfdd 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -6,22 +6,20 @@ import os
 import random
 import signal
 import sys
-from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig, TextStreamer
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 # add src to the pythonpath so we don't need to pip install this
 from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig, TextStreamer
 
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -204,9 +202,19 @@ def train(
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            if cfg.pretraining_dataset is True:
+                pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
+            else:
+                pretraining_dataset = cfg.pretraining_dataset
+            train_dataset = load_pretraining_dataset(
+                pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            )
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -256,7 +264,7 @@ def train(
         logging.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(
-                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]  # nosec
             ),
             tokenizer,
         )
@@ -265,10 +273,7 @@ def train(
         logging.info("Finished preparing dataset. Exiting...")
         return
 
-    try:
-        model.train()
-    except:
-        pass
+    model.train()
 
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
@@ -285,14 +290,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
-        def terminate_handler(signum, frame, model):
+
+        def terminate_handler(_, __, model):
             if cfg.flash_optimum:
                 model = BetterTransformer.reverse(model)
             model.save_pretrained(cfg.output_dir)
             sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signum, frame: terminate_handler(signum, frame, model)
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -316,7 +322,9 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index cba964076..49314372a 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -5,7 +5,8 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+import torch
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
@@ -392,3 +393,32 @@ def load_prepare_datasets(
         eval_dataset = dataset["test"]
 
     return train_dataset, eval_dataset
+
+
+class PretrainingDatasetWrapper(IterableDataset):
+    """
+    Wrapper for pretraining dataset that avoids loading the dataset into memory
+    """
+
+    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
+        self.tokenizer = tokenizer
+        self.dataset_path = dataset_path
+        self.max_tokens = max_tokens
+
+    def __iter__(self):
+        buffer = []
+        for sample in load_dataset(
+            self.dataset_path,
+            name="all",
+            split="train",
+            streaming=True,
+        ).shuffle(buffer_size=10000):
+            buffer += self.tokenizer(sample["text"])["input_ids"]
+            buffer += [self.tokenizer.eos_token_id]
+            while len(buffer) > self.max_tokens:
+                yield torch.tensor(buffer[: self.max_tokens])
+                buffer = buffer[self.max_tokens :]
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)