experimental expansion of ctx len

Wing Lian
2023-05-31 16:51:19 -04:00
parent 71a43f8479
commit 488a67d75a
2 changed files with 57 additions and 19 deletions


@@ -6,22 +6,20 @@ import os
import random
import signal
import sys
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
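The new optimum.bettertransformer import backs the flash_optimum code path used later in the diff. Roughly, BetterTransformer swaps a Hugging Face model's attention modules for fused kernels and has to be reversed before save_pretrained; a minimal sketch (the checkpoint name is illustrative):

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model = BetterTransformer.transform(model)  # patch in fused attention modules
# ... training / inference ...
model = BetterTransformer.reverse(model)    # restore vanilla modules before saving
model.save_pretrained("./output")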
@@ -204,9 +202,19 @@ def train(
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
if cfg.pretraining_dataset is True:
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
else:
pretraining_dataset = cfg.pretraining_dataset
train_dataset = load_pretraining_dataset(
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
)
eval_dataset = None
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
@@ -256,7 +264,7 @@ def train(
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
),
tokenizer,
)
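The added # nosec is bandit's inline suppression marker: it silences the B311 warning about random not being cryptographically secure, which is acceptable here because random.randrange only picks a handful of rows to print for debugging.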
@@ -265,10 +273,7 @@ def train(
logging.info("Finished preparing dataset. Exiting...")
return
try:
model.train()
except:
pass
model.train()
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
@@ -285,14 +290,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(signum, frame, model):
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, model)
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -316,7 +322,9 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
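For context, torch.backends.cuda.sdp_kernel (PyTorch 2.x) is a context manager that controls which scaled-dot-product-attention backends may be dispatched; enabling flash, math, and memory-efficient kernels lets PyTorch pick the fastest one the hardware supports. A standalone sketch, with illustrative shapes and a CUDA device assumed:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in fp16 so the flash kernel is eligible
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=True, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)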