streaming multipack for pretraining dataset (#959)

* [Feat] streaming multipack
* WIP make continued pretraining work w multipack
* fix up hardcoding, lint
* fix dict check
* update test for updated pretraining multipack code
* fix hardcoded data collator fix for multipack pretraining
* fix the collator to use the max length for multipack pretraining
* don't bother with latest tag for test
* cleanup docker build/test

---------

Co-authored-by: jinwonkim93@github.com <jinwonkim>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
.github/workflows/tests-docker.yml (vendored, 6 changes)
@@ -20,7 +20,6 @@ jobs:
             python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
-            is_latest: true
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
@@ -37,7 +36,7 @@ jobs:
           images: winglian/axolotl
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
-      - name: Build and export to Docker
+      - name: Build Docker image
         uses: docker/build-push-action@v5
         with:
           context: .
@@ -49,8 +48,7 @@ jobs:
           file: ./docker/Dockerfile
           tags: |
             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
           labels: ${{ steps.metadata.outputs.labels }}
-      - name: Unit Tests
+      - name: Unit Tests w docker image
         run: |
           docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
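Note: the expression matrix.axolotl_extras != '' && '-' || '' emulates a ternary in GitHub Actions syntax, so the extras suffix only gets a separating dash when extras are actually set. A minimal Python sketch of the resulting tag; the base metadata tag value here is hypothetical, for illustration only:

    # sketch of the image tag assembly; the base tag value is a hypothetical
    # stand-in for steps.metadata.outputs.tags
    base = "winglian/axolotl:main"
    python_version, cuda, pytorch = "3.10", "121", "2.0.1"
    axolotl_extras = ""  # empty unless the matrix entry sets extras
    tag = f"{base}-py{python_version}-cu{cuda}-{pytorch}" + (
        f"-{axolotl_extras}" if axolotl_extras else ""
    )
    print(tag)  # winglian/axolotl:main-py3.10-cu121-2.0.1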
examples/tiny-llama/pretrain.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
+base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+max_steps: 200
+pretraining_dataset:
+  path: c4
+  name: en
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./model-out
+
+sequence_len: 2048
+sample_packing: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
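Note: pretraining_dataset now accepts either a bare dataset path or a mapping with path and name keys; the prepare_dataset change later in this diff resolves both forms into a single streaming load_dataset call. A minimal sketch of that normalization:

    # sketch of how the two accepted config forms are normalized
    pretraining_dataset = {"path": "c4", "name": "en"}  # or simply "c4"
    path, name = pretraining_dataset, None
    if isinstance(pretraining_dataset, dict):
        path = pretraining_dataset["path"]
        name = pretraining_dataset["name"]
    # -> load_dataset("c4", name="en", streaming=True, split="train")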
src/axolotl/core/trainer_builder.py

@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
     sample_packing: bool = field(
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ class AxolotlTrainer(Trainer):
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
                 self.args.train_batch_size,
@@ -193,7 +199,7 @@ class AxolotlTrainer(Trainer):
         return super()._get_eval_sampler(eval_dataset)

     def get_train_dataloader(self) -> DataLoader:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             train_dataset = self.train_dataset
             train_dataset = train_dataset.remove_columns(["length"])
             data_collator = self.data_collator
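Note: a streaming pretraining dataset is an IterableDataset, which has no length and cannot be index-sampled, so the trainer now skips the MultipackBatchSampler path when pretraining is set; packing happens inside the dataset's map step instead. A minimal sketch of why a random sampler cannot wrap a streaming source:

    from torch.utils.data import IterableDataset, RandomSampler

    class Stream(IterableDataset):
        """Stand-in for a streaming (length-less) HF dataset."""
        def __iter__(self):
            yield from ({"input_ids": [1, 2, 3]} for _ in range(8))

    try:
        RandomSampler(Stream())  # needs len(data_source); streaming has none
    except TypeError as err:
        print(err)  # object of type 'Stream' has no len()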
@@ -768,6 +774,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs
             )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
@@ -808,7 +815,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=self.build_collator(**data_collator_kwargs),
+            data_collator=self.build_collator(training_args, **data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -829,7 +836,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         return trainer

-    def build_collator(self, **kwargs):
+    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+        if training_args.pretraining:
+            return None
+
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)

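Note: returning None here makes the HF Trainer fall back to its default collator (default_data_collator, or DataCollatorWithPadding when a tokenizer is passed), which simply tensor-stacks rows. That is appropriate because packed pretraining rows are already fully collated during the dataset map step. A small sketch of that fallback behavior, assuming the default_data_collator path:

    from transformers import default_data_collator

    # each mapped dataset row already holds one fully packed sequence
    rows = [{"input_ids": [1, 2, 3, 0], "labels": [1, 2, 3, -100]}]
    batch = default_data_collator(rows)
    print(batch["input_ids"].shape)  # torch.Size([1, 4])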
src/axolotl/utils/collators.py

@@ -178,3 +178,24 @@ class MambaDataCollator:
             "input_ids": input_ids,
             "labels": labels,
         }
+
+
+@dataclass
+class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack, specific to using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [(1) * np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
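Note: unlike a regular collator, this one receives a dict of already-grouped lists (one entry per packed sequence) and flattens each feature into a single long row before the parent DataCollatorForSeq2Seq pads it out to pad_to_multiple_of. A minimal numpy sketch of the concatenation step:

    import numpy as np

    # two packed sequences arrive as grouped lists per feature ...
    features = {"input_ids": [[1, 2, 3], [4, 5]],
                "attention_mask": [[1, 1, 1], [1, 1]]}
    # ... and are concatenated into one flat row per feature
    flat = {k: np.concatenate([np.array(i) for i in v]) for k, v in features.items()}
    print(flat["input_ids"])  # [1 2 3 4 5]; padded to a multiple of max length next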
src/axolotl/utils/data.py

@@ -2,6 +2,7 @@
 import functools
 import hashlib
 import logging
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple, Union

@@ -14,6 +15,7 @@ from datasets import (
     load_from_disk,
 )
 from huggingface_hub import hf_hub_download
+from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
 )
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.samplers.multipack import MultipackBatchSampler
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
+    process_pretraining_datasets_for_packing,
 )

 LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
+        path = cfg.pretraining_dataset
+        name = None
+        if isinstance(cfg.pretraining_dataset, dict):
+            path = cfg.pretraining_dataset["path"]
+            name = cfg.pretraining_dataset["name"]
+
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            path,
             tokenizer,
+            cfg,
+            name=name,
             max_tokens=cfg.sequence_len,
             seed=cfg.seed or 42,
         )
@@ -806,9 +819,27 @@ def encode_pretraining(
     return ret


-def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
-    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-    dataset = load_dataset(path, streaming=True, split="train")
+def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+    if cfg.sample_packing:
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+        )
+        encode = functools.partial(
+            encode_packed_pretraining,
+            tokenizer,
+            collate_fn,
+            max_seq_length=max_tokens,
+            batch_size=cfg.micro_batch_size,
+        )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
+        cfg.micro_batch_size = 1
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    dataset = load_dataset(path, streaming=True, split="train", name=name)
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     dataset = dataset.map(
         encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
         remove_columns=dataset.features.keys(),
     )
     return dataset
+
+
+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase,
+    collate_fn,
+    examples: List[str],
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
+) -> Dict[str, List]:
+    # pylint: disable=duplicate-code
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_seq_length - 1,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+
+    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
+    attention_mask = [seq + [1] for seq in res["attention_mask"]]
+
+    tokenized_examples = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+    }
+
+    train_dataset = Dataset.from_dict(tokenized_examples)
+    train_dataset = process_pretraining_datasets_for_packing(
+        train_dataset, max_seq_length
+    )
+
+    sampler = MultipackBatchSampler(
+        RandomSampler(train_dataset),
+        batch_size=batch_size,
+        drop_last=True,
+        batch_max_len=batch_size * max_seq_length,
+        lengths=(
+            train_dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        ),
+    )
+
+    chunked_data = defaultdict(list)
+
+    for data in sampler:
+        features = train_dataset[data]
+        features["labels"] = features["input_ids"].copy()
+        collated_features = collate_fn(features)
+
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+    return chunked_data
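Note: each call to encode_packed_pretraining turns a batch of raw texts into rows of exactly batch_size * max_seq_length tokens, and cfg.micro_batch_size is then reset to 1 so the dataloader does not batch the pre-batched rows again. A worked example under the values in examples/tiny-llama/pretrain.yml:

    # values from examples/tiny-llama/pretrain.yml
    sequence_len = 2048           # max_seq_length used for packing
    micro_batch_size = 2
    packed_row_len = micro_batch_size * sequence_len  # 4096 tokens per packed row
    # micro_batch_size is then forced to 1, so the train dataloader yields
    # tensors of shape (1, 4096), matching the assertions in
    # tests/test_packed_pretraining.py
    print(packed_row_len)  # 4096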
src/axolotl/utils/trainer.py

@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset


+def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+    drop_long = partial(drop_long_seq, sequence_len=sequence_len)
+
+    train_dataset = train_dataset.filter(drop_long)
+    train_dataset = train_dataset.map(
+        add_position_ids,
+    )
+    return train_dataset
+
+
 def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
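Note: the MultipackBatchSampler above recovers each sample's length from its last position id (the lambda x: x[-1] + 1), which works assuming add_position_ids assigns positions 0..len-1 to every sample. A minimal sketch of that invariant; the helper body here is illustrative, not the repo's exact implementation:

    # illustrative stand-in for axolotl's add_position_ids helper
    def add_position_ids(sample):
        sample["position_ids"] = list(range(len(sample["input_ids"])))
        return sample

    sample = add_position_ids({"input_ids": [5, 6, 7, 8]})
    length = sample["position_ids"][-1] + 1
    print(length)  # 4, the per-sample length the MultipackBatchSampler needs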
tests/test_packed_pretraining.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+"""Module for testing streaming dataset sequence packing"""
+import unittest
+from functools import partial
+
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
+
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data import encode_packed_pretraining
+
+
+class TestPacking(unittest.TestCase):
+    """
+    Test class for packing streaming dataset sequences
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.pad_token = "</s>"
+        self.max_seq_length = 2048
+        self.batch_size = 2
+
+    def test_packing_stream_dataset(self):
+        # pylint: disable=duplicate-code
+        dataset = load_dataset(
+            "c4",
+            "en",
+            streaming=True,
+        )["train"]
+
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=self.max_seq_length,
+        )
+
+        encode = partial(
+            encode_packed_pretraining,
+            self.tokenizer,
+            collate_fn,
+            max_seq_length=self.max_seq_length,
+            batch_size=self.batch_size,
+        )
+
+        dataset = dataset.map(
+            encode,
+            batched=True,
+            input_columns="text",
+            remove_columns=dataset.features.keys(),
+        )
+
+        trainer_loader = DataLoader(
+            dataset,
+            batch_size=1,
+            collate_fn=None,
+            drop_last=True,
+        )
+        idx = 0
+        for data in trainer_loader:
+            if idx > 10:
+                break
+            assert data["input_ids"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["position_ids"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["labels"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["attention_mask"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            idx += 1
+
+
+if __name__ == "__main__":
+    unittest.main()
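To run just this test locally (it streams c4 from the Hub, so network access is required), a standard pytest invocation works, for example from Python:

    # run only the streaming-packing test; equivalent to the pytest CLI with -k
    import pytest
    pytest.main(["tests/test_packed_pretraining.py", "-k", "test_packing_stream_dataset"])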