diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml
index 0aba6d505..e93884e64 100644
--- a/.github/workflows/tests-docker.yml
+++ b/.github/workflows/tests-docker.yml
@@ -20,7 +20,6 @@ jobs:
             python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
-            is_latest: true
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
@@ -37,7 +36,7 @@
           images: winglian/axolotl
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
-      - name: Build and export to Docker
+      - name: Build Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
@@ -49,8 +48,7 @@
           file: ./docker/Dockerfile
           tags: |
             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
          labels: ${{ steps.metadata.outputs.labels }}
-      - name: Unit Tests
+      - name: Unit Tests w docker image
        run: |
          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml
new file mode 100644
index 000000000..dfd1bfca2
--- /dev/null
+++ b/examples/tiny-llama/pretrain.yml
@@ -0,0 +1,58 @@
+base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+max_steps: 200
+pretraining_dataset:
+  path: c4
+  name: en
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./model-out
+
+sequence_len: 2048
+sample_packing: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 465cfa1af..26cc91ed5 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
     sample_packing: bool = field(
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ class AxolotlTrainer(Trainer):
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
                 self.args.train_batch_size,
@@ -193,7 +199,7 @@ class AxolotlTrainer(Trainer):
         return super()._get_eval_sampler(eval_dataset)

     def get_train_dataloader(self) -> DataLoader:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             train_dataset = self.train_dataset
             train_dataset = train_dataset.remove_columns(["length"])
             data_collator = self.data_collator
@@ -768,6 +774,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs
             )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
@@ -808,7 +815,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=self.build_collator(**data_collator_kwargs),
+            data_collator=self.build_collator(training_args, **data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -829,7 +836,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         return trainer

-    def build_collator(self, **kwargs):
+    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+        if training_args.pretraining:
+            return None
+
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
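A note on the trainer changes above: `pretraining` is derived from `bool(cfg.pretraining_dataset)`, and both the multipack sampler and the custom packed dataloader are skipped when it is set, because the streaming pretraining dataset already yields fully packed, pre-collated rows. A minimal sketch of that gating logic (the `Args` class here is a toy stand-in for `AxolotlTrainingArguments`, not the real API):

```python
# Toy sketch of the sampler gating added in trainer_builder.py.
from dataclasses import dataclass


@dataclass
class Args:
    sample_packing: bool = False
    pretraining: bool = False


def wants_multipack_sampler(args: Args) -> bool:
    # Multipack batching only applies to map-style (non-streaming) datasets;
    # during streaming pretraining each dataset row is already a packed batch.
    return args.sample_packing and not args.pretraining


assert wants_multipack_sampler(Args(sample_packing=True))
assert not wants_multipack_sampler(Args(sample_packing=True, pretraining=True))
```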
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index 0f0eb5a95..b9c1c3b3c 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -178,3 +178,24 @@
             "input_ids": input_ids,
             "labels": labels,
         }
+
+
+@dataclass
+class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack, specific to using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [(1) * np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
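To see what the new collator does with one multipack batch, here is a toy walkthrough with plain numpy (invented sample values, not the axolotl API): each feature's list of sequences is concatenated into a single packed row, the bookkeeping `length` column is dropped, and the result is handed to `DataCollatorForSeq2Seq`, which pads that one row out to `pad_to_multiple_of`.

```python
# Toy walkthrough (invented values) of the concatenation step in
# PretrainingBatchSamplerDataCollatorForSeq2Seq.__call__.
import numpy as np

features = {
    "input_ids": [[5, 6, 7], [8, 9]],       # two packed sequences
    "attention_mask": [[1, 1, 1], [1, 1]],
    "length": [3, 2],                        # dropped by the collator
}

chunked_data = {}
for feature in features:
    if feature == "length":
        continue
    chunked_data[feature] = np.concatenate(
        [np.array(item) for item in features[feature]]
    )

print(chunked_data["input_ids"])       # [5 6 7 8 9]
print(chunked_data["attention_mask"])  # [1 1 1 1 1]
# DataCollatorForSeq2Seq then pads this single row to pad_to_multiple_of.
```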
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 5c41d16fe..b3c7606eb 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -2,6 +2,7 @@
 import functools
 import hashlib
 import logging
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple, Union

@@ -14,6 +15,7 @@ from datasets import (
     load_from_disk,
 )
 from huggingface_hub import hf_hub_download
+from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
 )
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.samplers.multipack import MultipackBatchSampler
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
+    process_pretraining_datasets_for_packing,
 )

 LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
+        path = cfg.pretraining_dataset
+        name = None
+        if isinstance(cfg.pretraining_dataset, dict):
+            path = cfg.pretraining_dataset["path"]
+            name = cfg.pretraining_dataset["name"]
+
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            path,
             tokenizer,
+            cfg,
+            name=name,
             max_tokens=cfg.sequence_len,
             seed=cfg.seed or 42,
         )
@@ -806,9 +819,27 @@ def encode_pretraining(
     return ret


-def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
-    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-    dataset = load_dataset(path, streaming=True, split="train")
+def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+    if cfg.sample_packing:
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+        )
+        encode = functools.partial(
+            encode_packed_pretraining,
+            tokenizer,
+            collate_fn,
+            max_seq_length=max_tokens,
+            batch_size=cfg.micro_batch_size,
+        )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
+        cfg.micro_batch_size = 1
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    dataset = load_dataset(path, streaming=True, split="train", name=name)
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     dataset = dataset.map(
         encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
         remove_columns=dataset.features.keys(),
     )
     return dataset
+
+
+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase,
+    collate_fn,
+    examples: List[str],
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
+) -> Dict[str, List]:
+    # pylint: disable=duplicate-code
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_seq_length - 1,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+
+    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
+    attention_mask = [seq + [1] for seq in res["attention_mask"]]
+
+    tokenized_examples = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+    }
+
+    train_dataset = Dataset.from_dict(tokenized_examples)
+    train_dataset = process_pretraining_datasets_for_packing(
+        train_dataset, max_seq_length
+    )
+
+    sampler = MultipackBatchSampler(
+        RandomSampler(train_dataset),
+        batch_size=batch_size,
+        drop_last=True,
+        batch_max_len=batch_size * max_seq_length,
+        lengths=(
+            train_dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        ),
+    )
+
+    chunked_data = defaultdict(list)
+
+    for data in sampler:
+        features = train_dataset[data]
+        features["labels"] = features["input_ids"].copy()
+        collated_features = collate_fn(features)
+
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+    return chunked_data
diff --git
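The flow in `encode_packed_pretraining` above: tokenize with `return_overflowing_tokens` and a 256-token stride so long documents split into overlapping windows, append EOS, drop rows longer than `max_seq_length`, attach per-example `position_ids`, then bin examples with `MultipackBatchSampler` into batches of at most `batch_size * max_seq_length` tokens. The `lengths=` expression works because `add_position_ids` numbers each example `0..len-1`; a toy check (invented values, not library code):

```python
# Toy check of the lengths= expression passed to MultipackBatchSampler:
# after add_position_ids, each example carries position_ids 0..len-1,
# so its token count can be read back cheaply as position_ids[-1] + 1.
examples = [
    {"position_ids": [0, 1, 2, 3]},  # 4 tokens
    {"position_ids": [0, 1, 2]},     # 3 tokens
]
lengths = [ex["position_ids"][-1] + 1 for ex in examples]
assert lengths == [4, 3]
```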
a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d975bb9a2..3139f5600 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset


+def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+    drop_long = partial(drop_long_seq, sequence_len=sequence_len)
+
+    train_dataset = train_dataset.filter(drop_long)
+    train_dataset = train_dataset.map(
+        add_position_ids,
+    )
+    return train_dataset
+
+
 def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
new file mode 100644
index 000000000..a47e3983f
--- /dev/null
+++ b/tests/test_packed_pretraining.py
@@ -0,0 +1,82 @@
+"""Module for testing streaming dataset sequence packing"""
+import unittest
+from functools import partial
+
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
+
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data import encode_packed_pretraining
+
+
+class TestPacking(unittest.TestCase):
+    """
+    Test class for packing streaming dataset sequences
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.pad_token = "</s>"
+        self.max_seq_length = 2048
+        self.batch_size = 2
+
+    def test_packing_stream_dataset(self):
+        # pylint: disable=duplicate-code
+        dataset = load_dataset(
+            "c4",
+            "en",
+            streaming=True,
+        )["train"]
+
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=self.max_seq_length,
+        )
+
+        encode = partial(
+            encode_packed_pretraining,
+            self.tokenizer,
+            collate_fn,
+            max_seq_length=self.max_seq_length,
+            batch_size=self.batch_size,
+        )
+
+        dataset = dataset.map(
+            encode,
+            batched=True,
+            input_columns="text",
+            remove_columns=dataset.features.keys(),
+        )
+
+        trainer_loader = DataLoader(
+            dataset,
+            batch_size=1,
+            collate_fn=None,
+            drop_last=True,
+        )
+        idx = 0
+        for data in trainer_loader:
+            if idx > 10:
+                break
+            assert data["input_ids"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["position_ids"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["labels"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            assert data["attention_mask"].shape == torch.Size(
+                [1, self.batch_size * self.max_seq_length]
+            )
+            idx += 1
+
+
+if __name__ == "__main__":
+    unittest.main()
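As a sanity check on the shapes the test asserts: with the values from `examples/tiny-llama/pretrain.yml` (`micro_batch_size: 2`, `sequence_len: 2048`), the collator pads every packed row to `micro_batch_size * sequence_len` tokens, `load_pretraining_dataset` then resets `cfg.micro_batch_size` to 1, and the trainer's dataloader therefore sees one `[1, 4096]` tensor per step. A quick back-of-the-envelope:

```python
# Back-of-the-envelope for the tensor shapes asserted in the test, using
# the values from the example config above.
micro_batch_size = 2  # batch_size passed to encode_packed_pretraining
sequence_len = 2048   # max_seq_length

pad_to_multiple_of = micro_batch_size * sequence_len  # 4096
# cfg.micro_batch_size is reset to 1, so the DataLoader wraps each
# pre-packed row as a batch of one:
expected_shape = (1, pad_to_multiple_of)
assert expected_shape == (1, 4096)
```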