From 3aba4c5d7cc579ed933aff1fcf59b3ba3953614b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 17 Jul 2023 23:44:14 -0400
Subject: [PATCH] use multi pack dataloader w random sampler

---
 requirements.txt                |   1 +
 src/axolotl/datasets.py         |  10 --
 src/axolotl/utils/dataloader.py | 209 ++++++++++++++++++++++++++++++++
 src/axolotl/utils/trainer.py    |  79 ++++++++++--
 4 files changed, 282 insertions(+), 17 deletions(-)
 create mode 100644 src/axolotl/utils/dataloader.py

diff --git a/requirements.txt b/requirements.txt
index 98a57c66a..cd7d9f033 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ einops
 xformers
 optimum
 hf_transfer
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 4376fb18a..bc137d238 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -81,7 +81,6 @@ class ConstantLengthDataset(IterableDataset):
             "input_ids": [],
             "attention_mask": [],
             "labels": [],
-            "position_ids": [],
         }
         buffer_len = 0
         for dataset in self.datasets:
@@ -113,9 +112,6 @@
                     attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                         : self.seq_length
                     ]
-                    position_ids = torch.cat(buffer["position_ids"], dim=-1)[
-                        : self.seq_length
-                    ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                     if labels.size() == input_ids.size() and (
                         attention_mask.size() == input_ids.size()
@@ -124,7 +120,6 @@
                             "input_ids": input_ids,
                             "labels": labels,
                             "attention_mask": attention_mask,
-                            "position_ids": position_ids,
                         }
                     else:
                         LOG.warning(
@@ -134,7 +129,6 @@
                     "input_ids": [],
                     "attention_mask": [],
                     "labels": [],
-                    "position_ids": [],
                 }
                 buffer_len = 0
                 idx = 1
@@ -161,12 +155,8 @@
                 labels_with_concat = torch.tensor(
                     labels, dtype=self.tokens_dtype
                 )
-                position_ids = torch.arange(
-                    len(input_ids), dtype=self.tokens_dtype
-                )
                 buffer["input_ids"].append(input_ids_with_concat)
                 buffer["attention_mask"].append(attention_mask_with_concat)
                 buffer["labels"].append(labels_with_concat)
-                buffer["position_ids"].append(position_ids)
                 buffer_len += len(input_ids)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
new file mode 100644
index 000000000..2a95749d2
--- /dev/null
+++ b/src/axolotl/utils/dataloader.py
@@ -0,0 +1,209 @@
+# pylint: skip-file
+
+from typing import Any, Callable, List
+
+import numba
+import numpy as np
+from torch.utils.data import DistributedSampler
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
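+            # remember which original sequence index seeded this new bin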
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result, len(a)
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+    result_totseqs = []
+
+    while True:
+        # binary search [left, right)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length left
+        batch, tot_seqs = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+        # add total seqs for all ranks
+        result_totseqs.append(tot_seqs)
+
+    return result, result_totseqs, s, len(result) * c * n
+
+
+class MultipackDistributedDataloader:
+    """Unpadded data loading using Multipack.
+    Approximate (at most ~1.22x) the optimal solution of the identical-machines
+    scheduling problem, which is NP-hard.
+    """
+
+    def __init__(
+        self,
+        dataset: Any,
+        collate_fn: Callable,
+        seq_max_length: int = 2048,
+        batch_size: int = 1,
+        sampler: DistributedSampler = None,
+        seed: int = 0,
+    ):
+        # Dataset
+        self.dataset = dataset
+        self.lengths: np.ndarray = np.array(
+            [len(sample["input_ids"]) for sample in self.dataset]
+        )
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.seq_max_length = seq_max_length
+        self.batch_max_length = batch_size * seq_max_length
+        self.collate_fn = collate_fn
+
+        self.num_replicas = 1
+        self.rank = 0
+
+        # Seed
+        self.seed = seed
+
+        # Epoch
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = [idx for idx in self.sampler]
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, totseqs, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=self.rank,
+            c=self.batch_max_length,
+            n=self.num_replicas,
+        )
+
+        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches, totseqs
+
+    def __iter__(self):
+        all_batches, _ = self.generate_batches(set_stats=True)
+        features = self.dataset.features.keys()
+        for batch in all_batches:
+            concatenated = {}
+            batched = [self.dataset[batch_idx] for batch_idx in batch]
+            for feature in features:
+                if feature == "attention_mask":
+                    arrays = [
+                        (idx + 1) * np.array(item[feature])
+                        for idx, item in enumerate(batched)
+                        if feature in item
+                    ]
+                    concatenated[feature] = np.concatenate(arrays)
+                else:
+                    arrays = [
+                        np.array(item[feature]) for item in batched if feature in item
+                    ]
+                    concatenated[feature] = np.concatenate(arrays)
+            num_chunks = int(
+                np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
+            )
+            chunked_data = []
+
+            for i in range(num_chunks):
+                chunk = {
+                    feature: array[
+                        i * self.seq_max_length : (i + 1) * self.seq_max_length
+                    ]
+                    for feature, array in concatenated.items()
+                }
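+                # collect the fixed-length slice; collate_fn receives one dict per chunk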
+                chunked_data.append(chunk)
+            yield self.collate_fn(chunked_data)
+
+    def __len__(self):
+        batches, _ = self.generate_batches()
+        return len(batches)
+
+    def num_batches(self):
+        batches, _ = self.generate_batches()
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9c398e21e..38d5f0e3b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,5 +1,4 @@
 """Module containing the Trainer class and related functions"""
-
 import importlib
 import logging
 import math
@@ -7,13 +6,15 @@ import os
 import sys
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import bitsandbytes as bnb
 import torch.cuda
 import transformers
+from datasets import Dataset
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
+from torch.utils.data import DataLoader, DistributedSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -21,7 +22,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
-from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
     get_cosine_schedule_with_quadratic_warmup,
 )
@@ -40,6 +41,14 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    sample_packing: bool = field(
+        default=False,
+        metadata={"help": "Use sample packing for efficient training."},
+    )
+    max_seq_length: int = field(
+        default=2048,
+        metadata={"help": "The maximum sequence length the model can handle"},
+    )


 class AxolotlTrainer(Trainer):
@@ -77,6 +86,32 @@ class AxolotlTrainer(Trainer):
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler

+    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            train_sampler = self._get_train_sampler()
+            return MultipackDistributedDataloader(
+                self.train_dataset,
+                batch_size=self._train_batch_size,
+                seq_max_length=self.args.max_seq_length,
+                collate_fn=self.data_collator,
+                sampler=train_sampler,
+            )
+        return super().get_train_dataloader()
+
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return MultipackDistributedDataloader(
+                eval_dataset,
+                batch_size=self.args.per_device_eval_batch_size,
+                seq_max_length=self.args.max_seq_length,
+                collate_fn=self.data_collator,
+                sampler=eval_sampler,
+            )
+        return super().get_eval_dataloader(eval_dataset)
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
@@ -108,9 +143,36 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):


 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
-    total_num_steps = int(
-        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
-    )
+    if cfg.sample_packing:
+        sampler = DistributedSampler(
+            train_dataset,
+            num_replicas=1,
+            rank=0,
+            seed=cfg.seed,
+        )
+        data_loader = MultipackDistributedDataloader(
+            train_dataset,
+            batch_size=cfg.micro_batch_size,
+            seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
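+            # pad each packed chunk to the longest sequence it contains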
+            collate_fn=transformers.DataCollatorForSeq2Seq(
+                tokenizer,
+                return_tensors="pt",
+                padding="longest",
+            ),
+            sampler=sampler,
+        )
+        total_num_steps = int(
+            math.ceil(
+                len(data_loader)
+                * cfg.micro_batch_size
+                * cfg.num_epochs
+                / cfg.batch_size
+            )
+        )
+    else:
+        total_num_steps = int(
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+        )
     warmup_steps = (
         cfg.warmup_steps
         if cfg.warmup_steps is not None
@@ -191,6 +253,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+        max_steps=total_num_steps
+        * cfg.num_epochs,  # this is helpful in case we don't actually know total # of steps
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -222,6 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
+        sample_packing=cfg.sample_packing if cfg.sample_packing else False,
         **training_arguments_kwargs,
     )
@@ -347,7 +412,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        data_collator=DataCollatorForSeq2Seq(
+        data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             **data_collator_kwargs,
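
Usage sketch (single process): the loop below drives the new
MultipackDistributedDataloader directly, mirroring the sampler and collator
wiring from setup_trainer() in the diff. The toy dataset, the
huggyllama/llama-7b tokenizer, and the pad-token assignment are illustrative
assumptions, not part of the patch.

    import transformers
    from datasets import Dataset
    from torch.utils.data import DistributedSampler

    from axolotl.utils.dataloader import MultipackDistributedDataloader

    # hypothetical pre-tokenized dataset: rows of widely varying length
    lengths = (12, 700, 1500, 48, 2048, 9)
    data = Dataset.from_dict(
        {
            "input_ids": [[1] * n for n in lengths],
            "attention_mask": [[1] * n for n in lengths],
            "labels": [[1] * n for n in lengths],
        }
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    tokenizer.pad_token = tokenizer.eos_token  # LLaMA ships without a pad token

    # single-process setup, matching the num_replicas=1 / rank=0 defaults above
    sampler = DistributedSampler(data, num_replicas=1, rank=0, seed=42)
    loader = MultipackDistributedDataloader(
        data,
        batch_size=1,
        seq_max_length=2048,
        collate_fn=transformers.DataCollatorForSeq2Seq(
            tokenizer, return_tensors="pt", padding="longest"
        ),
        sampler=sampler,
    )

    for step, batch in enumerate(loader):
        # each yielded batch packs several sequences into seq_max_length-sized chunks
        print(step, batch["input_ids"].shape)
    print(f"packing efficiency: {loader.efficiency():.1%}")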