experimental expansion of ctx len

This commit is contained in:
Wing Lian
2023-05-31 16:51:19 -04:00
parent 71a43f8479
commit 488a67d75a
2 changed files with 57 additions and 19 deletions

View File

@@ -6,22 +6,20 @@ import os
import random import random
import signal import signal
import sys import sys
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
@@ -204,9 +202,19 @@ def train(
if check_not_in( if check_not_in(
["inference", "shard", "merge_lora"], kwargs ["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these ): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( if not cfg.pretraining_dataset:
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH train_dataset, eval_dataset = load_prepare_datasets(
) tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
if cfg.pretraining_dataset is True:
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
else:
pretraining_dataset = cfg.pretraining_dataset
train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
)
eval_dataset = None
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...") logging.info("check_dataset_labels...")
@@ -256,7 +264,7 @@ def train(
logging.info("check_dataset_labels...") logging.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] [random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
), ),
tokenizer, tokenizer,
) )
@@ -265,10 +273,7 @@ def train(
logging.info("Finished preparing dataset. Exiting...") logging.info("Finished preparing dataset. Exiting...")
return return
try: model.train()
model.train()
except:
pass
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
@@ -285,14 +290,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(signum, frame, model):
def terminate_handler(_, __, model):
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
sys.exit(0) sys.exit(0)
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
lambda signum, frame: terminate_handler(signum, frame, model)
) )
logging.info("Starting trainer...") logging.info("Starting trainer...")
@@ -316,7 +322,9 @@ def train(
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -5,7 +5,8 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk import torch
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -392,3 +393,32 @@ def load_prepare_datasets(
eval_dataset = dataset["test"] eval_dataset = dataset["test"]
return train_dataset, eval_dataset return train_dataset, eval_dataset
class PretrainingDatasetWrapper(IterableDataset):
    """
    Iterable wrapper that streams a pretraining corpus and packs it into
    fixed-length chunks of token ids.

    The corpus is opened with ``streaming=True`` each time iteration starts,
    so nothing is fetched at construction time.
    """

    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
        """
        Store the settings used later by ``__iter__``.

        :param tokenizer: callable returning a mapping with an ``"input_ids"``
            entry; must also expose ``eos_token_id``
        :param dataset_path: path/name handed to ``load_dataset``
        :param max_tokens: number of token ids in every yielded chunk
        """
        self.max_tokens = max_tokens
        self.dataset_path = dataset_path
        self.tokenizer = tokenizer

    def __iter__(self):
        """
        Yield ``torch`` tensors holding exactly ``max_tokens`` token ids each.

        Every sample's ``"text"`` field is tokenized and followed by the
        tokenizer's EOS id; samples are concatenated back to back, so a chunk
        may span several samples. Tokens still pending when the stream ends
        are not yielded.
        """
        stream = load_dataset(
            self.dataset_path,
            name="all",
            split="train",
            streaming=True,
        )
        shuffled_stream = stream.shuffle(buffer_size=10000)

        pending = []
        for record in shuffled_stream:
            encoded = self.tokenizer(record["text"])
            pending.extend(encoded["input_ids"])
            pending.append(self.tokenizer.eos_token_id)
            # Strictly greater-than: a chunk is only emitted once at least one
            # extra token is waiting behind it.
            while len(pending) > self.max_tokens:
                yield torch.tensor(pending[: self.max_tokens])
                del pending[: self.max_tokens]
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
    """
    Build a ``PretrainingDatasetWrapper`` for the dataset at ``path``.

    :param path: dataset path/name forwarded to the wrapper
    :param tokenizer: tokenizer forwarded to the wrapper
    :param max_tokens: chunk length forwarded to the wrapper
    :return: the newly constructed ``PretrainingDatasetWrapper``
    """
    wrapper = PretrainingDatasetWrapper(
        tokenizer=tokenizer,
        dataset_path=path,
        max_tokens=max_tokens,
    )
    return wrapper