Pretrain transforms (#1261)

* wip for pretraining/iterable data with arbitrary prompt strategies

* more fixes, wip

* more fixes for custom pretraining

* iterable ds wrapper not needed

* remove extra features

* chore: lint

* update pretraining example yml

* fix order for partials

* fixup for tests
This commit is contained in:
Wing Lian
2024-02-06 00:37:03 -05:00
committed by GitHub
parent 8c2e05ade3
commit c7cf3810bd
5 changed files with 145 additions and 62 deletions

View File

@@ -12,6 +12,7 @@ max_steps: 200
pretraining_dataset: pretraining_dataset:
path: c4 path: c4
name: en name: en
type: pretrain
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./model-out output_dir: ./model-out

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: Dataset,
process_count: Optional[int] = None, process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False, keep_in_memory: Optional[bool] = False,
**kwargs, **kwargs,

View File

@@ -0,0 +1,58 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
class PretrainTokenizer:
    """Basic prompter for pretraining: the raw text itself is the prompt."""

    def build_prompt(self, prompt) -> Generator[str, None, None]:
        """Yield ``prompt`` unchanged as the single generated item."""
        yield prompt
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
    """Handles tokenization for pretraining, splitting long rows with a stride.

    Rows longer than the configured ``max_length`` are not dropped: the
    tokenizer is asked to return the overflowing tokens as additional
    sequences that overlap by ``stride`` tokens.
    """

    @property
    def supports_batched(self):
        """Always ``True`` for this strategy."""
        return True

    def __init__(
        self, *args, max_length=None, text_column="text", stride=256, **kwargs
    ):
        """Create the strategy.

        Args:
            *args: Forwarded unchanged to ``PromptTokenizingStrategy``.
            max_length: When truthy, overrides ``self.max_length``; otherwise
                whatever the base class initialised is kept.
                NOTE(review): assumes the base ``__init__`` sets
                ``self.max_length`` - confirm.
            text_column: Key of the row field holding the raw text.
            stride: Number of tokens shared between consecutive overflow
                windows.
            **kwargs: Forwarded unchanged to ``PromptTokenizingStrategy``.
        """
        super().__init__(*args, **kwargs)
        if max_length:
            self.max_length = max_length
        self.text_column = text_column
        self.stride = stride

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
        """Tokenize ``prompt`` into one or more EOS-terminated sequences.

        ``add_eos_token`` and ``strip_bos_token`` are accepted but not
        consulted here: an EOS token is always appended to every returned
        sequence.
        """
        res = self.tokenizer(
            prompt,
            truncation=True,
            # leave room for the EOS token appended below
            max_length=self.max_length - 1,
            add_special_tokens=True,
            return_overflowing_tokens=True,
            stride=self.stride,
        )
        eos_token_id = self.tokenizer.eos_token_id
        res["input_ids"] = [seq + [eos_token_id] for seq in res["input_ids"]]
        # the appended EOS token is attended to like any other token
        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
        return res

    def tokenize_prompt(self, prompt):
        """Tokenize the configured text column (``"text"`` by default) of ``prompt``."""
        return self._tokenize(prompt[self.text_column])
def load(tokenizer, cfg):
    """Build the pretraining tokenization strategy from ``tokenizer`` and ``cfg``.

    Reads ``cfg.train_on_inputs`` and ``cfg.sequence_len``; the strategy's
    ``max_length`` is set to 64 times ``cfg.sequence_len``.
    """
    prompter = PretrainTokenizer()
    return PretrainTokenizationStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )

View File

@@ -4,7 +4,7 @@ import hashlib
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import yaml import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
path = cfg.pretraining_dataset[0]["path"] path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"] name = cfg.pretraining_dataset[0]["name"]
train_dataset = load_pretraining_dataset( ds_wrapper_partial = functools.partial(
path, get_dataset_wrapper,
cfg.pretraining_dataset[0],
tokenizer, tokenizer,
cfg, cfg,
name=name, cfg.pretraining_dataset[0]["type"] or "pretrain",
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name),
tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42, seed=cfg.seed or 42,
) )
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -383,9 +392,9 @@ def load_tokenized_prepared_datasets(
dataset_wrapper, dataset_prompter = get_dataset_wrapper( dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset, config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer, tokenizer=tokenizer,
cfg=cfg, cfg=cfg,
dataset=ds,
d_base_type=d_base_type, d_base_type=d_base_type,
d_prompt_style=d_prompt_style, d_prompt_style=d_prompt_style,
) )
@@ -496,7 +505,12 @@ def load_prepare_datasets(
def get_dataset_wrapper( def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style config_dataset,
tokenizer,
cfg,
d_base_type,
dataset,
d_prompt_style=None,
): ):
dataset_wrapper = None dataset_wrapper = None
dataset_prompter = None dataset_prompter = None
@@ -507,7 +521,8 @@ def get_dataset_wrapper(
} }
if ( if (
"input_ids" in dataset.features isinstance(dataset, Dataset)
and "input_ids" in dataset.features
and "attention_mask" in dataset.features and "attention_mask" in dataset.features
and "labels" in dataset.features and "labels" in dataset.features
): ):
@@ -765,69 +780,60 @@ def encode_pretraining(
return ret return ret
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing: if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size, pad_to_multiple_of=max_tokens * batch_size,
) )
encode = functools.partial( encode = functools.partial(
encode_packed_pretraining, encode_packed_pretraining,
tokenizer,
collate_fn, collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens, max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size, batch_size=batch_size,
) )
# set this to 1 so downstream data_loader doesn't try to increase the batch again # set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1 cfg.micro_batch_size = 1
else: else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train", name=name) dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map( dataset = dataset.map(
encode, encode,
batched=True, batched=True,
batch_size=10_000, batch_size=buffer_size,
input_columns="text", # input_columns="text",
# remove all the existing columns after mapping since they end up having # remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column # a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(), remove_columns=dataset.features.keys(),
desc="Encoding Pretraining",
) )
return dataset return dataset
def encode_packed_pretraining( def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn, collate_fn,
examples: List[str], ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048, max_seq_length: int = 2048,
batch_size: int = 4, batch_size: int = 4,
) -> Dict[str, List]: ) -> Dict[str, List]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
# tokenize all the examples # tokenize all the examples
# rows get split with stride (overlap) # rows get split with stride (overlap)
res = tokenizer( train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing( train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length train_dataset, max_seq_length
) )
@@ -845,7 +851,14 @@ def encode_packed_pretraining(
for batch in sampler: for batch in sampler:
for data in batch: for data in batch:
features = train_dataset[data] features = train_dataset[data]
features["labels"] = features["input_ids"].copy() if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features) collated_features = collate_fn(features)
for feature in features.keys(): for feature in features.keys():

View File

@@ -1,14 +1,14 @@
"""Module for testing streaming dataset sequence packing""" """Module for testing streaming dataset sequence packing"""
import functools
import unittest import unittest
from functools import partial
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.data import encode_packed_pretraining from axolotl.utils.dict import DictDefault
class TestPretrainingPacking(unittest.TestCase): class TestPretrainingPacking(unittest.TestCase):
@@ -20,8 +20,6 @@ class TestPretrainingPacking(unittest.TestCase):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>" self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2
def test_packing_stream_dataset(self): def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -31,30 +29,43 @@ class TestPretrainingPacking(unittest.TestCase):
streaming=True, streaming=True,
)["train"] )["train"]
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( cfg = DictDefault(
self.tokenizer, {
return_tensors="pt", "pretraining_dataset": [
padding=True, {
pad_to_multiple_of=self.max_seq_length, "path": "c4",
"name": "en",
"type": "pretrain",
}
],
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"micro_batch_size": 2,
}
) )
encode = partial( ds_wrapper_partial = functools.partial(
encode_packed_pretraining, get_dataset_wrapper,
cfg.pretraining_dataset[0],
self.tokenizer, self.tokenizer,
collate_fn, cfg,
max_seq_length=self.max_seq_length, cfg.pretraining_dataset[0]["type"] or "pretrain",
batch_size=self.batch_size,
) )
dataset = dataset.map( original_bsz = cfg.micro_batch_size
encode, train_dataset = wrap_pretraining_dataset(
batched=True, dataset,
input_columns="text", self.tokenizer,
remove_columns=dataset.features.keys(), cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
) )
trainer_loader = DataLoader( trainer_loader = DataLoader(
dataset, train_dataset,
batch_size=1, batch_size=1,
collate_fn=None, collate_fn=None,
drop_last=True, drop_last=True,
@@ -64,16 +75,16 @@ class TestPretrainingPacking(unittest.TestCase):
if idx > 10: if idx > 10:
break break
assert data["input_ids"].shape == torch.Size( assert data["input_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length] [1, original_bsz * cfg.sequence_len]
) )
assert data["position_ids"].shape == torch.Size( assert data["position_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length] [1, original_bsz * cfg.sequence_len]
) )
assert data["labels"].shape == torch.Size( assert data["labels"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length] [1, original_bsz * cfg.sequence_len]
) )
assert data["attention_mask"].shape == torch.Size( assert data["attention_mask"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length] [1, original_bsz * cfg.sequence_len]
) )
idx += 1 idx += 1