add streaming dataset support for pretraining datasets
This commit is contained in:
@@ -14,7 +14,6 @@ import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
@@ -208,14 +207,11 @@ def train(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
if cfg.pretraining_dataset is True:
|
||||
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
|
||||
else:
|
||||
pretraining_dataset = cfg.pretraining_dataset
|
||||
train_dataset = load_pretraining_dataset(
|
||||
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
|
||||
cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
|
||||
)
|
||||
train_dataset = Dataset.from_list(list(train_dataset))
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
@@ -262,19 +258,6 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
return
|
||||
|
||||
if cfg.debug:
|
||||
logging.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
|
||||
),
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
if prepare_ds_only:
|
||||
logging.info("Finished preparing dataset. Exiting...")
|
||||
return
|
||||
|
||||
model.train()
|
||||
|
||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user