more tweaks to do pre-training with bettertransformers

This commit is contained in:
Wing Lian
2023-05-31 21:59:15 -04:00
parent ed7531abb8
commit 86bd9fcff4
6 changed files with 54 additions and 12 deletions


@@ -14,6 +14,7 @@ import torch
 import yaml
 # add src to the pythonpath so we don't need to pip install this
+from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig
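
Alongside the new Dataset import, this hunk shows the BetterTransformer import from optimum in context; the actual call sites sit in the other files this commit touches. As a minimal, hypothetical sketch of how optimum's BetterTransformer API typically wraps a training run (the checkpoint name and output path are placeholders, not taken from this commit):

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
model = BetterTransformer.transform(model)  # swap attention blocks for fused kernels
# ... run the pre-training loop against the transformed model ...
model = BetterTransformer.reverse(model)  # restore canonical modules before saving
model.save_pretrained("output")  # placeholder path

Calling BetterTransformer.reverse before saving matters because the transformed modules are not serializable in the standard Transformers checkpoint format.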
@@ -204,6 +205,7 @@ def train(
         train_dataset = load_pretraining_dataset(
             pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
         )
+        train_dataset = Dataset.from_list(list(train_dataset))
     eval_dataset = None
     if cfg.debug or "debug" in kwargs:
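
The one added line materializes the pretraining data, which load_pretraining_dataset presumably yields as a lazy iterable, into an in-memory datasets.Dataset so the trainer can take its length and index it. A minimal sketch of that pattern, with a hypothetical generator standing in for load_pretraining_dataset:

from datasets import Dataset

def pretraining_samples():  # hypothetical stand-in for load_pretraining_dataset
    for i in range(3):
        yield {"input_ids": [i, i + 1], "attention_mask": [1, 1]}

# list(...) drains the iterable; Dataset.from_list builds a map-style dataset
train_dataset = Dataset.from_list(list(pretraining_samples()))
print(len(train_dataset), train_dataset[0])  # supports len() and indexing

Materializing up front trades memory for compatibility: a map-style Dataset works with shuffling and epoch-based training, at the cost of holding every tokenized sample in RAM.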