Pretrain transforms (#1261)

* wip for pretraining/iterable data with arbitrary prompt strategies

* more fixes, wip

* more fixes for custom pretraining

* iterable ds wrapper not needed

* remove extra features

* chore: lint

* update pretraning example yml

* fix order for partials

* fixup for tests
This commit is contained in:
Wing Lian
2024-02-06 00:37:03 -05:00
committed by GitHub
parent 8c2e05ade3
commit c7cf3810bd
5 changed files with 145 additions and 62 deletions

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,

View File

@@ -0,0 +1,58 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
class PretrainTokenizer:
"""basic tokenization class for pretraining"""
def build_prompt(self, prompt) -> Generator[str, None, None]:
yield prompt
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
"""handles tokenization for pretraining with strides"""
@property
def supports_batched(self):
return True
def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length:
self.max_length = max_length
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
res = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
res["input_ids"] = [
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
return res
def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"])
def load(tokenizer, cfg):
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
return strat

View File

@@ -4,7 +4,7 @@ import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
train_dataset = load_pretraining_dataset(
path,
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
tokenizer,
cfg,
name=name,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name),
tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -383,9 +392,9 @@ def load_tokenized_prepared_datasets(
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
@@ -496,7 +505,12 @@ def load_prepare_datasets(
def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
config_dataset,
tokenizer,
cfg,
d_base_type,
dataset,
d_prompt_style=None,
):
dataset_wrapper = None
dataset_prompter = None
@@ -507,7 +521,8 @@ def get_dataset_wrapper(
}
if (
"input_ids" in dataset.features
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
@@ -765,69 +780,60 @@ def encode_pretraining(
return ret
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
pad_to_multiple_of=max_tokens * batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
batch_size=batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
batch_size=buffer_size,
# input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
desc="Encoding Pretraining",
)
return dataset
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
examples: List[str],
ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
@@ -845,7 +851,14 @@ def encode_packed_pretraining(
for batch in sampler:
for data in batch:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():