Fix pretraining with iterable/streaming Dataset (#556)
* return without packing prep/len * fix remove columns * fix encode arguments * add error when max steps not set * fix test --------- Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9845c5e12d
commit
2f586d18db
@@ -191,6 +191,10 @@ def validate_config(cfg):
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||||
)
|
)
|
||||||
|
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||||
|
)
|
||||||
|
|
||||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
return train_dataset, eval_dataset, cfg.max_steps
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
@@ -527,9 +528,11 @@ def load_prepare_datasets(
|
|||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(tokenizer, max_tokens, examples):
|
def encode_pretraining(
|
||||||
|
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
||||||
|
) -> Dict[str, List]:
|
||||||
res = tokenizer(
|
res = tokenizer(
|
||||||
examples["text"],
|
examples,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_tokens - 2,
|
max_length=max_tokens - 2,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
|||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||||
dataset = load_dataset(path, streaming=True, split="train")
|
dataset = load_dataset(path, streaming=True, split="train")
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||||
# TODO dynamically figure out which columns/features to remove
|
dataset = dataset.map(
|
||||||
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
encode,
|
||||||
|
batched=True,
|
||||||
|
input_columns="text",
|
||||||
|
remove_columns=[
|
||||||
|
"text",
|
||||||
|
],
|
||||||
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
"hello, hello",
|
"hello, hello",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
|
||||||
|
|
||||||
self.assertEqual(len(result["input_ids"]), 3)
|
self.assertEqual(len(result["input_ids"]), 3)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user