experimental expansion of ctx len
This commit is contained in:
@@ -6,22 +6,20 @@ import os
|
|||||||
import random
|
import random
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from transformers import GenerationConfig
|
|
||||||
|
|
||||||
from axolotl.utils.data import load_prepare_datasets
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
from transformers import GenerationConfig
|
||||||
|
|
||||||
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
from axolotl.utils.validation import validate_config
|
from axolotl.utils.validation import validate_config
|
||||||
@@ -194,9 +192,19 @@ def train(
|
|||||||
if check_not_in(
|
if check_not_in(
|
||||||
["inference", "shard", "merge_lora"], kwargs
|
["inference", "shard", "merge_lora"], kwargs
|
||||||
): # don't need to load dataset for these
|
): # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
if not cfg.pretraining_dataset:
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
)
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if cfg.pretraining_dataset is True:
|
||||||
|
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
|
||||||
|
else:
|
||||||
|
pretraining_dataset = cfg.pretraining_dataset
|
||||||
|
train_dataset = load_pretraining_dataset(
|
||||||
|
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
|
||||||
|
)
|
||||||
|
eval_dataset = None
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
logging.info("check_dataset_labels...")
|
logging.info("check_dataset_labels...")
|
||||||
@@ -246,7 +254,7 @@ def train(
|
|||||||
logging.info("check_dataset_labels...")
|
logging.info("check_dataset_labels...")
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_dataset.select(
|
||||||
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
)
|
)
|
||||||
@@ -255,10 +263,7 @@ def train(
|
|||||||
logging.info("Finished preparing dataset. Exiting...")
|
logging.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
model.train()
|
||||||
model.train()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||||
|
|
||||||
@@ -275,14 +280,15 @@ def train(
|
|||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
def terminate_handler(signum, frame, model):
|
|
||||||
|
def terminate_handler(_, __, model):
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
signal.signal(
|
signal.signal(
|
||||||
signal.SIGINT,
|
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||||
lambda signum, frame: terminate_handler(signum, frame, model)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Starting trainer...")
|
logging.info("Starting trainer...")
|
||||||
@@ -304,7 +310,9 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
|
):
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from hashlib import md5
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
import torch
|
||||||
|
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
@@ -385,3 +386,32 @@ def load_prepare_datasets(
|
|||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainingDatasetWrapper(IterableDataset):
    """
    Wrapper for pretraining dataset that avoids loading the dataset into memory

    The corpus is opened in streaming mode, shuffled through a bounded
    buffer, tokenized one sample at a time and packed into fixed-length
    chunks of ``max_tokens`` token ids. The tokenizer's EOS token id is
    appended after every sample so document boundaries survive packing.
    Tokens left over when the stream ends (fewer than ``max_tokens``) are
    dropped so every yielded tensor has the same length.
    """

    # NOTE(review): the base-class ``__init__`` is not invoked here (the
    # original did not invoke it either) -- confirm the imported
    # ``IterableDataset`` base tolerates being subclassed this way.
    def __init__(
        self,
        tokenizer,
        dataset_path,
        max_tokens=2048,
        *,
        config_name="all",
        shuffle_buffer_size=10000,
    ):
        """
        Args:
            tokenizer: callable returning a mapping with an ``"input_ids"``
                entry for a text sample; must expose ``eos_token_id``.
            dataset_path: path or hub identifier handed to ``load_dataset``.
            max_tokens: number of token ids in every yielded chunk.
            config_name: dataset configuration passed as ``name=`` to
                ``load_dataset``; defaults to ``"all"`` as before.
            shuffle_buffer_size: size of the streaming shuffle buffer.

        Raises:
            ValueError: if ``max_tokens`` is missing or not positive, which
                would otherwise make iteration loop forever on empty chunks.
        """
        if max_tokens is None or max_tokens < 1:
            raise ValueError(
                f"max_tokens must be a positive integer, got {max_tokens!r}"
            )
        self.tokenizer = tokenizer
        self.dataset_path = dataset_path
        self.max_tokens = max_tokens
        self.config_name = config_name
        self.shuffle_buffer_size = shuffle_buffer_size

    def __iter__(self):
        """Yield 1-D ``torch`` tensors of exactly ``max_tokens`` token ids."""
        chunk_len = self.max_tokens
        eos_token_id = self.tokenizer.eos_token_id
        stream = load_dataset(
            self.dataset_path,
            name=self.config_name,
            split="train",
            streaming=True,
        ).shuffle(buffer_size=self.shuffle_buffer_size)

        buffer = []
        for sample in stream:
            buffer.extend(self.tokenizer(sample["text"])["input_ids"])
            buffer.append(eos_token_id)

            # Emit every complete chunk currently buffered. Using the full
            # multiple of ``chunk_len`` (rather than ``len > chunk_len``)
            # also releases a buffer that is *exactly* one chunk long, and
            # trimming once per sample avoids re-copying the tail per chunk.
            full_len = len(buffer) - len(buffer) % chunk_len
            for start in range(0, full_len, chunk_len):
                yield torch.tensor(buffer[start : start + chunk_len])
            del buffer[:full_len]
|
||||||
|
|
||||||
|
|
||||||
|
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
    """
    Create the streaming pretraining dataset for ``path``.

    Args:
        path: dataset path or identifier forwarded to the wrapper.
        tokenizer: tokenizer the wrapper uses to encode each sample.
        max_tokens: length, in token ids, of each packed chunk.

    Returns:
        A ``PretrainingDatasetWrapper`` that produces the chunks lazily
        when iterated.
    """
    wrapper = PretrainingDatasetWrapper(
        tokenizer,
        path,
        max_tokens=max_tokens,
    )
    return wrapper
|
||||||
|
|||||||
Reference in New Issue
Block a user