streaming multipack for pretraining dataset (#959)

* [Feat] streaming multipack

* WIP make continued pretraining work w multipack

* fix up hadrcoding, lint

* fix dict check

* update test for updated pretraining multipack code

* fix hardcoded data collator fix for multipack pretraining

* fix the collator to be the max length for multipack pretraining

* don't bother with latest tag for test

* cleanup docker build/test

---------

Co-authored-by: jinwonkim93@github.com <jinwonkim>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
JinK
2024-01-06 12:13:21 +09:00
committed by GitHub
parent eb4c99431b
commit 553c80f79a
7 changed files with 282 additions and 12 deletions

View File

@@ -20,7 +20,6 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -37,7 +36,7 @@ jobs:
images: winglian/axolotl images: winglian/axolotl
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Build and export to Docker - name: Build Docker image
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
@@ -49,8 +48,7 @@ jobs:
file: ./docker/Dockerfile file: ./docker/Dockerfile
tags: | tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests - name: Unit Tests w docker image
run: | run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/

View File

@@ -0,0 +1,58 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
max_steps: 200
pretraining_dataset:
path: c4
name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False, default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."}, metadata={"help": "Use quadratic warmup for cosine scheduling."},
) )
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field( sample_packing: bool = field(
default=False, default=False,
metadata={"help": "Use sample packing for efficient training."}, metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ class AxolotlTrainer(Trainer):
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing: if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler( return MultipackBatchSampler(
RandomSampler(self.train_dataset), RandomSampler(self.train_dataset),
self.args.train_batch_size, self.args.train_batch_size,
@@ -193,7 +199,7 @@ class AxolotlTrainer(Trainer):
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing: if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"]) train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator data_collator = self.data_collator
@@ -768,6 +774,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs training_arguments_kwargs
) )
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.neftune_noise_alpha is not None: if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[ training_arguments_kwargs[
@@ -808,7 +815,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
data_collator=self.build_collator(**data_collator_kwargs), data_collator=self.build_collator(training_args, **data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq( bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer, self.tokenizer,
return_tensors="pt", return_tensors="pt",
@@ -829,7 +836,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer return trainer
def build_collator(self, **kwargs): def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
if training_args.pretraining:
return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer) return MambaDataCollator(tokenizer=self.tokenizer)

View File

@@ -178,3 +178,24 @@ class MambaDataCollator:
"input_ids": input_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
} }
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features.keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)

View File

@@ -2,6 +2,7 @@
import functools import functools
import hashlib import hashlib
import logging import logging
from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@@ -14,6 +15,7 @@ from datasets import (
load_from_disk, load_from_disk,
) )
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@ from axolotl.prompters import (
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter, UnsupportedPrompter,
) )
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
process_pretraining_datasets_for_packing,
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset["path"]
name = cfg.pretraining_dataset["name"]
train_dataset = load_pretraining_dataset( train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset, path,
tokenizer, tokenizer,
cfg,
name=name,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
seed=cfg.seed or 42, seed=cfg.seed or 42,
) )
@@ -806,9 +819,27 @@ def encode_pretraining(
return ret return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) if cfg.sample_packing:
dataset = load_dataset(path, streaming=True, split="train") collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map( dataset = dataset.map(
encode, encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
remove_columns=dataset.features.keys(), remove_columns=dataset.features.keys(),
) )
return dataset return dataset
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
examples: List[str],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=(
train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
)
chunked_data = defaultdict(list)
for data in sampler:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data

View File

@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset return train_dataset, eval_dataset
def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(drop_long)
train_dataset = train_dataset.map(
add_position_ids,
)
return train_dataset
def calculate_total_num_steps(cfg, train_dataset, update=True): def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
total_num_tokens = np.sum( total_num_tokens = np.sum(

View File

@@ -0,0 +1,82 @@
"""Module for testing streaming dataset sequence packing"""
import unittest
from functools import partial
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining
class TestPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"c4",
"en",
streaming=True,
)["train"]
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
)
encode = partial(
encode_packed_pretraining,
self.tokenizer,
collate_fn,
max_seq_length=self.max_seq_length,
batch_size=self.batch_size,
)
dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
)
trainer_loader = DataLoader(
dataset,
batch_size=1,
collate_fn=None,
drop_last=True,
)
idx = 0
for data in trainer_loader:
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["position_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["labels"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["attention_mask"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
idx += 1
if __name__ == "__main__":
unittest.main()