experimental expansion of ctx len

This commit is contained in:
Wing Lian
2023-05-31 16:51:19 -04:00
parent 71a43f8479
commit 488a67d75a
2 changed files with 57 additions and 19 deletions

View File

@@ -6,22 +6,20 @@ import os
import random
import signal
import sys
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
@@ -204,9 +202,19 @@ def train(
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
if cfg.pretraining_dataset is True:
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
else:
pretraining_dataset = cfg.pretraining_dataset
train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
)
eval_dataset = None
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
@@ -256,7 +264,7 @@ def train(
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
),
tokenizer,
)
@@ -265,10 +273,7 @@ def train(
logging.info("Finished preparing dataset. Exiting...")
return
try:
model.train()
except:
pass
model.train()
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
@@ -285,14 +290,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(signum, frame, model):
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, model)
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -316,7 +322,9 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -5,7 +5,8 @@ from hashlib import md5
from pathlib import Path
from typing import List, Tuple, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import torch
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -392,3 +393,32 @@ def load_prepare_datasets(
eval_dataset = dataset["test"]
return train_dataset, eval_dataset
class PretrainingDatasetWrapper(IterableDataset):
"""
Wrapper for pretraining dataset that avoids loading the dataset into memory
"""
def __init__(self, tokenizer, dataset_path, max_tokens=2048):
self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.max_tokens = max_tokens
def __iter__(self):
buffer = []
for sample in load_dataset(
self.dataset_path,
name="all",
split="train",
streaming=True,
).shuffle(buffer_size=10000):
buffer += self.tokenizer(sample["text"])["input_ids"]
buffer += [self.tokenizer.eos_token_id]
while len(buffer) > self.max_tokens:
yield torch.tensor(buffer[: self.max_tokens])
buffer = buffer[self.max_tokens :]
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)