Fix pretraining with iterable/streaming Dataset (#556)

* return early for streaming datasets, skipping packing prep/len

* fix remove_columns (only drop "text"; no longer assume a "meta" column exists)

* fix encode_pretraining arguments (receive the list of texts via input_columns="text")

* raise an error when max_steps is not set while using pretraining_dataset

* fix test

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
This commit is contained in:
Jan Philipp Harries
2023-09-13 06:16:40 +02:00
committed by GitHub
parent 9845c5e12d
commit 2f586d18db
3 changed files with 19 additions and 6 deletions

View File

@@ -191,6 +191,10 @@ def validate_config(cfg):
LOG.warning( LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely." "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
) )
if cfg.pretraining_dataset and not cfg.max_steps:
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer not cfg.optimizer or "adamw" not in cfg.optimizer

View File

@@ -3,7 +3,7 @@ import functools
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
@@ -527,9 +528,11 @@ def load_prepare_datasets(
return train_dataset, eval_dataset return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples): def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples["text"], examples,
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train") dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove dataset = dataset.map(
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) encode,
batched=True,
input_columns="text",
remove_columns=[
"text",
],
)
return dataset return dataset

View File

@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello", "hello, hello",
] ]
} }
result = encode_pretraining(self.tokenizer, self.max_tokens, examples) result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
self.assertEqual(len(result["input_ids"]), 3) self.assertEqual(len(result["input_ids"]), 3)