Fix pretraining with iterable/streaming Dataset (#556)

* return without packing prep/len
* fix remove columns
* fix encode arguments
* add error when max steps not set
* fix test

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
committed by GitHub
parent 9845c5e12d
commit 2f586d18db
@@ -191,6 +191,10 @@ def validate_config(cfg):
         LOG.warning(
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
+    if cfg.pretraining_dataset and not cfg.max_steps:
+        raise ValueError(
+            "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
+        )
 
     if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
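Note: a streamed pretraining_dataset is an IterableDataset, which has no length, so the Trainer cannot size the optimizer/LR schedule on its own; that is why the new check requires max_steps. A minimal, self-contained illustration (not taken from this commit):

    from datasets import IterableDataset

    def gen():
        # Tiny stand-in corpus; a real streamed dataset behaves the same way.
        yield {"text": "hello world"}

    stream = IterableDataset.from_generator(gen)
    # len(stream)  # raises TypeError: an IterableDataset has no __len__,
    #              # so the Trainer cannot infer the number of training steps.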
@@ -3,7 +3,7 @@ import functools
 import hashlib
 import logging
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
+        return train_dataset, eval_dataset, cfg.max_steps
 
     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
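Note: with a streamed pretraining_dataset, prepare_dataset now returns early, skipping packing prep and length calculation and passing cfg.max_steps through as the step count. A self-contained sketch of the same streaming pattern (the dataset name is an assumption for illustration):

    from datasets import load_dataset

    # Stream, format for torch, and skip evaluation; no length is ever computed.
    train_dataset = load_dataset("allenai/c4", "en", streaming=True, split="train")
    train_dataset = train_dataset.with_format("torch")
    eval_dataset = None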
@@ -527,9 +528,11 @@ def load_prepare_datasets(
     return train_dataset, eval_dataset
 
 
-def encode_pretraining(tokenizer, max_tokens, examples):
+def encode_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples,
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
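Note: encode_pretraining now receives the raw "text" values as a list of strings rather than a dict of columns. A hedged usage sketch (the tokenizer choice and import path are assumptions, not from this commit):

    from transformers import AutoTokenizer

    from axolotl.utils.data import encode_pretraining  # assumed import path

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
    batch = encode_pretraining(tokenizer, 2048, ["hello world", "hello, hello"])
    print(list(batch.keys()))  # dict of lists, e.g. input_ids / attention_mask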
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
-    # TODO dynamically figure out which columns/features to remove
-    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    dataset = dataset.map(
+        encode,
+        batched=True,
+        input_columns="text",
+        remove_columns=[
+            "text",
+        ],
+    )
     return dataset
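Note: passing input_columns="text" makes datasets.map hand the mapped function the column's values directly (a list of strings when batched), which matches the new encode_pretraining signature and avoids hard-coding unrelated columns such as "meta" in remove_columns. A small self-contained sketch of that map behavior (toy data, not from this commit):

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["a", "bb"], "meta": [1, 2]})

    # With input_columns, the lambda receives only the "text" values as a list.
    ds = ds.map(
        lambda texts: {"n_chars": [len(t) for t in texts]},
        batched=True,
        input_columns="text",
    )
    print(ds[0])  # {'text': 'a', 'meta': 1, 'n_chars': 1}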
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
 
         self.assertEqual(len(result["input_ids"]), 3)