Pretrain transforms (#1261)

* wip for pretraining/iterable data with arbitrary prompt strategies

* more fixes, wip

* more fixes for custom pretraining

* iterable ds wrapper not needed

* remove extra features

* chore: lint

* update pretraining example yml

* fix order for partials

* fixup for tests

Wing Lian
2024-02-06 00:37:03 -05:00
committed by GitHub
parent 8c2e05ade3
commit c7cf3810bd
5 changed files with 145 additions and 62 deletions


@@ -12,6 +12,7 @@ max_steps: 200
 pretraining_dataset:
   path: c4
   name: en
+  type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
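
A note on the new key: `type: pretrain` selects which prompt strategy tokenizes the
stream, falling back to the new pretrain strategy when unset. A minimal sketch of the
resolution, assuming axolotl's usual prompt_strategies.load(...) convention (the exact
signature and the helper name below are assumptions, not part of this commit):

    # Hedged sketch: how a pretraining_dataset entry's "type" maps to a strategy.
    from axolotl.prompt_strategies import load as load_strategy  # assumed import path

    def resolve_pretrain_strategy(tokenizer, cfg):
        ds_cfg = cfg.pretraining_dataset[0]  # {"path": "c4", "name": "en", "type": "pretrain"}
        return load_strategy(ds_cfg["type"] or "pretrain", tokenizer, cfg, ds_cfg)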


@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
     def __init__(  # pylint: disable=super-init-not-called
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
-        dataset: IterableDataset,
+        dataset: Dataset,
         process_count: Optional[int] = None,
         keep_in_memory: Optional[bool] = False,
         **kwargs,
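
With the annotation corrected, TokenizedPromptDataset is explicitly a map-style
wrapper; streaming IterableDataset sources now go through wrap_pretraining_dataset
instead. A minimal usage sketch, assuming `strategy` is any PromptTokenizingStrategy
instance and that the class lives in axolotl.datasets:

    # Hedged sketch: the wrapper consumes an in-memory datasets.Dataset,
    # not a streaming IterableDataset.
    from datasets import Dataset
    from axolotl.datasets import TokenizedPromptDataset  # assumed module path

    rows = Dataset.from_dict({"text": ["first document", "second document"]})
    tokenized = TokenizedPromptDataset(strategy, rows, process_count=1)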


@@ -0,0 +1,58 @@
+"""pretraining prompt strategies"""
+from typing import Generator
+
+from transformers import BatchEncoding
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+
+class PretrainTokenizer:
+    """basic tokenization class for pretraining"""
+
+    def build_prompt(self, prompt) -> Generator[str, None, None]:
+        yield prompt
+
+
+class PretrainTokenizationStrategy(PromptTokenizingStrategy):
+    """handles tokenization for pretraining with strides"""
+
+    @property
+    def supports_batched(self):
+        return True
+
+    def __init__(self, *args, max_length=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if max_length:
+            self.max_length = max_length
+
+    def _tokenize(
+        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
+    ) -> BatchEncoding:
+        res = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length - 1,
+            add_special_tokens=True,
+            return_overflowing_tokens=True,
+            stride=256,
+        )
+        res["input_ids"] = [
+            seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
+        ]
+        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+        return res
+
+    def tokenize_prompt(self, prompt):
+        return self._tokenize(prompt["text"])
+
+
+def load(tokenizer, cfg):
+    strat = PretrainTokenizationStrategy(
+        PretrainTokenizer(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        max_length=cfg.sequence_len * 64,
+    )
+    return strat
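
The stride/overflow call above is the heart of the strategy: a long document is split
into overlapping windows instead of being truncated to a single sequence. A toy
illustration of the same tokenizer behavior (gpt2 stands in for the model tokenizer;
the values are illustrative):

    # Hedged sketch of return_overflowing_tokens + stride on one long document.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    res = tok(
        "a very long pretraining document " * 2000,
        truncation=True,
        max_length=2048 - 1,  # leave room for the EOS appended per window
        return_overflowing_tokens=True,
        stride=256,  # consecutive windows overlap by 256 tokens
    )
    # Each window gets an explicit EOS, mirroring _tokenize above.
    res["input_ids"] = [seq + [tok.eos_token_id] for seq in res["input_ids"]]
    res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
    print(len(res["input_ids"]))  # > 1: the document became several overlapping rows

The loader's max_length=cfg.sequence_len * 64 presumably gives each mapped batch room
for long windows that the downstream packing step then cuts into sequence_len chunks.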


@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
         path = cfg.pretraining_dataset[0]["path"]
         name = cfg.pretraining_dataset[0]["name"]
-        train_dataset = load_pretraining_dataset(
-            path,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             tokenizer,
             cfg,
-            name=name,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
+        )
+        train_dataset = wrap_pretraining_dataset(
+            load_dataset(path, streaming=True, split="train", name=name),
+            tokenizer,
+            cfg,
+            ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
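
The partial freezes everything except the dataset itself, matching the reordered
get_dataset_wrapper signature below ("fix order for partials"), so the streaming
wrapper can apply the prompt strategy lazily, batch by batch. Sketched usage (the toy
Dataset value is illustrative):

    # Hedged sketch: ds_wrapper_partial only lacks the positional `dataset` argument;
    # get_dataset_wrapper returns (dataset_wrapper, dataset_prompter).
    from datasets import Dataset

    wrapped, _prompter = ds_wrapper_partial(Dataset.from_dict({"text": ["hello"]}))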
@@ -383,9 +392,9 @@ load_tokenized_prepared_datasets(
         dataset_wrapper, dataset_prompter = get_dataset_wrapper(
             config_dataset=config_dataset,
-            dataset=ds,
             tokenizer=tokenizer,
             cfg=cfg,
+            dataset=ds,
             d_base_type=d_base_type,
             d_prompt_style=d_prompt_style,
         )
@@ -496,7 +505,12 @@ load_prepare_datasets(
 def get_dataset_wrapper(
-    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
+    config_dataset,
+    tokenizer,
+    cfg,
+    d_base_type,
+    dataset,
+    d_prompt_style=None,
 ):
     dataset_wrapper = None
     dataset_prompter = None
@@ -507,7 +521,8 @@ get_dataset_wrapper(
     }
 
     if (
-        "input_ids" in dataset.features
+        isinstance(dataset, Dataset)
+        and "input_ids" in dataset.features
         and "attention_mask" in dataset.features
         and "labels" in dataset.features
     ):
@@ -765,69 +780,60 @@ encode_pretraining(
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+def wrap_pretraining_dataset(
+    dataset,
+    tokenizer,
+    cfg,
+    ds_wrapper_fn,
+    max_tokens=2048,
+    batch_size=1,
+    seed=42,
+    buffer_size=10_000,
+):
     if cfg.sample_packing:
         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+            pad_to_multiple_of=max_tokens * batch_size,
         )
         encode = functools.partial(
             encode_packed_pretraining,
-            tokenizer,
             collate_fn,
+            ds_wrapper_fn,
             max_seq_length=max_tokens,
-            batch_size=cfg.micro_batch_size,
+            batch_size=batch_size,
         )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
+        cfg.micro_batch_size = 1
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = load_dataset(path, streaming=True, split="train", name=name)
-    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
     dataset = dataset.map(
         encode,
         batched=True,
-        batch_size=10_000,
-        input_columns="text",
+        batch_size=buffer_size,
+        # input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
         remove_columns=dataset.features.keys(),
+        desc="Encoding Pretraining",
     )
     return dataset
 
 
 def encode_packed_pretraining(
-    tokenizer: PreTrainedTokenizerBase,
     collate_fn,
-    examples: List[str],
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
     # rows get split with stride (overlap)
-    res = tokenizer(
-        examples,
-        truncation=True,
-        max_length=max_seq_length - 1,
-        add_special_tokens=True,
-        return_overflowing_tokens=True,
-        stride=256,
-    )
+    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
 
-    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
-    attention_mask = [seq + [1] for seq in res["attention_mask"]]
-    tokenized_examples = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-    }
-    train_dataset = Dataset.from_dict(tokenized_examples)
     train_dataset = process_pretraining_datasets_for_packing(
         train_dataset, max_seq_length
     )
@@ -845,7 +851,14 @@ encode_packed_pretraining(
     for batch in sampler:
         for data in batch:
             features = train_dataset[data]
-            features["labels"] = features["input_ids"].copy()
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "overflow_to_sample_mapping" in features:
+                del features["overflow_to_sample_mapping"]
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
             collated_features = collate_fn(features)
             for feature in features.keys():
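
Two details above are worth connecting: the collator pads to max_tokens * batch_size,
so each packed example leaving encode_packed_pretraining is already one flattened
super-batch, and wrap_pretraining_dataset then forces cfg.micro_batch_size to 1 so the
downstream data loader does not batch again. The arithmetic, sketched with this
commit's defaults:

    # Hedged sketch: sequence_len=2048 packed at micro_batch_size=2 yields single
    # rows of 2 * 2048 = 4096 tokens; the batch dimension is baked into the row.
    max_tokens, batch_size = 2048, 2
    assert max_tokens * batch_size == 4096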


@@ -1,14 +1,14 @@
 """Module for testing streaming dataset sequence packing"""
+import functools
 import unittest
-from functools import partial
 
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 
 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data import encode_packed_pretraining
+from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.dict import DictDefault
 
 
 class TestPretrainingPacking(unittest.TestCase):
@@ -20,8 +20,6 @@ class TestPretrainingPacking(unittest.TestCase):
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.pad_token = "</s>"
-        self.max_seq_length = 2048
-        self.batch_size = 2
 
     def test_packing_stream_dataset(self):
         # pylint: disable=duplicate-code
@@ -31,30 +29,43 @@ class TestPretrainingPacking(unittest.TestCase):
             streaming=True,
         )["train"]
 
-        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-            self.tokenizer,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=self.max_seq_length,
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": [
+                    {
+                        "path": "c4",
+                        "name": "en",
+                        "type": "pretrain",
+                    }
+                ],
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "micro_batch_size": 2,
+            }
         )
-        encode = partial(
-            encode_packed_pretraining,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             self.tokenizer,
-            collate_fn,
-            max_seq_length=self.max_seq_length,
-            batch_size=self.batch_size,
+            cfg,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
-        dataset = dataset.map(
-            encode,
-            batched=True,
-            input_columns="text",
-            remove_columns=dataset.features.keys(),
+        original_bsz = cfg.micro_batch_size
+        train_dataset = wrap_pretraining_dataset(
+            dataset,
+            self.tokenizer,
+            cfg,
+            ds_wrapper_partial,
+            max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
+            seed=cfg.seed or 42,
         )
         trainer_loader = DataLoader(
-            dataset,
+            train_dataset,
             batch_size=1,
             collate_fn=None,
             drop_last=True,
@@ -64,16 +75,16 @@ class TestPretrainingPacking(unittest.TestCase):
             if idx > 10:
                 break
             assert data["input_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["position_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["labels"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["attention_mask"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             idx += 1
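
The shape assertions restate the packing arithmetic: collation already happened inside
the map step, so the DataLoader uses batch_size=1 and collate_fn=None, and every tensor
it yields is a single row of original_bsz * cfg.sequence_len tokens. A sketch of the
expected size under this config:

    # Hedged sketch: original_bsz=2 and sequence_len=2048 give [1, 4096] tensors.
    import torch

    assert torch.Size([1, 2 * 2048]) == torch.Size([1, 4096])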