Fix pretraining with iterable/streaming Dataset (#556)

* return without packing prep/len

* fix remove columns

* fix encode arguments

* add error when max steps not set

* fix test

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
This commit is contained in:
Jan Philipp Harries
2023-09-13 06:16:40 +02:00
committed by GitHub
parent 9845c5e12d
commit 2f586d18db
3 changed files with 19 additions and 6 deletions

View File

@@ -191,6 +191,10 @@ def validate_config(cfg):
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if cfg.pretraining_dataset and not cfg.max_steps:
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer

View File

@@ -3,7 +3,7 @@ import functools
import hashlib
import logging
from pathlib import Path
from typing import Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
@@ -527,9 +528,11 @@ def load_prepare_datasets(
return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples,
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=[
"text",
],
)
return dataset

View File

@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
self.assertEqual(len(result["input_ids"]), 3)