More tweaks to do pre-training with BetterTransformer

Wing Lian
2023-05-31 21:59:15 -04:00
parent 488a67d75a
commit 1210dc8fd5
6 changed files with 54 additions and 12 deletions


@@ -14,6 +14,7 @@ import torch
 import yaml
 # add src to the pythonpath so we don't need to pip install this
 from datasets import Dataset
+from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
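
For context, a minimal sketch of how this new import is typically used (not shown in this hunk; the model name and the reverse step are illustrative assumptions, not the repo's code):

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

# hypothetical model; any architecture supported by BetterTransformer works
model = AutoModelForCausalLM.from_pretrained("gpt2")
# swap supported attention/encoder modules for fused BetterTransformer kernels
model = BetterTransformer.transform(model)
# ... training loop runs on the transformed model ...
# revert to the canonical transformers modules before saving checkpoints
model = BetterTransformer.reverse(model)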
@@ -214,6 +215,7 @@ def train(
     train_dataset = load_pretraining_dataset(
         pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
     )
+    train_dataset = Dataset.from_list(list(train_dataset))
     eval_dataset = None
     if cfg.debug or "debug" in kwargs:
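
The added Dataset.from_list(list(train_dataset)) line materializes the streamed pretraining data into an in-memory, map-style Dataset. A minimal sketch of that pattern, where the generator below is a stand-in for load_pretraining_dataset and not the repo's actual code:

from datasets import Dataset

def token_stream():
    # stand-in for load_pretraining_dataset: yields tokenized example dicts
    for i in range(3):
        yield {"input_ids": [i, i + 1], "attention_mask": [1, 1]}

# consume the iterable, then build an indexable Dataset from the row dicts
train_dataset = Dataset.from_list(list(token_stream()))
print(len(train_dataset), train_dataset[0])  # 3 {'input_ids': [0, 1], ...}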