use multi pack dataloader w random sampler

This commit is contained in:
Wing Lian
2023-07-17 23:44:14 -04:00
parent ffd96839cf
commit 3aba4c5d7c
4 changed files with 282 additions and 17 deletions

requirements.txt (View File)

@@ -13,6 +13,7 @@ einops
 xformers
 optimum
 hf_transfer
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0

src/axolotl/datasets.py (View File)

@@ -81,7 +81,6 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
@@ -113,9 +112,6 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
-                position_ids = torch.cat(buffer["position_ids"], dim=-1)[
-                    : self.seq_length
-                ]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
@@ -124,7 +120,6 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
@@ -134,7 +129,6 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
@@ -161,12 +155,8 @@ class ConstantLengthDataset(IterableDataset):
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
-                    position_ids = torch.arange(
-                        len(input_ids), dtype=self.tokens_dtype
-                    )
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

src/axolotl/utils/dataloader.py (View File)

@@ -0,0 +1,209 @@
# pylint: skip-file
from typing import Any, Callable, List, Optional

import numba
import numpy as np
from torch.utils.data import DistributedSampler


@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
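
# A quick illustration (not part of the original commit): four sequences of
# lengths [5, 3, 2, 4] fit into two bins of capacity 8 (5+3 and 4+2), but
# not into one bin, since their total length is 14:
#
#   >>> ffd_check(np.array([5, 3, 2, 4]), 8, 2)
#   True
#   >>> ffd_check(np.array([5, 3, 2, 4]), 8, 1)
#   False
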
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
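
# Illustrative only: the same lengths with capacity 8 and start_index=100.
# The packer returns the bin contents as dataset indices shifted by
# start_index, plus the number of sequences consumed:
#
#   >>> ffd_with_result(np.array([5, 3, 2, 4]), 8, 100)
#   ([[100, 101], [103, 102]], 4)
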
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
return result, result_totseqs, s, len(result) * c * n
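
# Worked example (illustrative only): lengths [5, 3, 2, 4, 6, 1], c=8, n=2.
# The binary search grows the prefix to [5, 3, 2, 4], which packs into the
# two bins [0, 1] and [3, 2]; rank 0 keeps [0, 1]. The leftover [6, 1] fits
# in a single bin, fewer than n, so allocation stops. The result is
# ([[0, 1]], [4], 14, 16): 14 tokens used out of 2 * 8 available slots.
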
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
        sampler: Optional[DistributedSampler] = None,
seed: int = 0,
):
# Dataset
self.dataset = dataset
self.lengths: np.ndarray = np.array(
[len(sample["input_ids"]) for sample in self.dataset]
)
assert isinstance(self.lengths, np.ndarray)
self.sampler = sampler
self.batch_size = batch_size
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_replicas = 1
self.rank = 0
# Seed
self.seed = seed
# Epoch
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
    def set_epoch(self, epoch: int):
        self.epoch = epoch
        # keep the sampler's shuffling in sync across epochs
        if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
def generate_batches(self, set_stats=False):
        indices = list(self.sampler)
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
c=self.batch_max_length,
n=self.num_replicas,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def __iter__(self):
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
for batch in all_batches:
concatenated = {}
batched = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched)
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in batched if feature in item
]
concatenated[feature] = np.concatenate(arrays)
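            # split the packed batch back into chunks of at most
            # seq_max_length so the collator never sees a sequence longer
            # than the model context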
num_chunks = int(
np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
)
chunked_data = []
for i in range(num_chunks):
chunk = {
feature: array[
i * self.seq_max_length : (i + 1) * self.seq_max_length
]
for feature, array in concatenated.items()
}
chunked_data.append(chunk)
yield self.collate_fn(chunked_data)
def __len__(self):
batches, _ = self.generate_batches()
return len(batches)
def num_batches(self):
batches, _ = self.generate_batches()
return len(batches)
    def efficiency(self):
        if self.eff_total_slots == 0:
            return 0.0
        return self.eff_total_used / self.eff_total_slots
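
A minimal usage sketch (not part of the commit): it feeds a toy tokenized
dataset through the loader with a pass-through collate_fn so the packed
chunks can be inspected. The toy data, the lambda collator, and the tiny
seq_max_length are assumptions for illustration; in the commit itself the
loader is built inside setup_trainer with transformers.DataCollatorForSeq2Seq.

from datasets import Dataset
from torch.utils.data import DistributedSampler

from axolotl.utils.dataloader import MultipackDistributedDataloader

# toy "tokenized" dataset: six samples of lengths 5, 3, 2, 4, 6, 1
lengths = (5, 3, 2, 4, 6, 1)
data = {
    "input_ids": [list(range(n)) for n in lengths],
    "attention_mask": [[1] * n for n in lengths],
    "labels": [list(range(n)) for n in lengths],
}
dataset = Dataset.from_dict(data)

sampler = DistributedSampler(dataset, num_replicas=1, rank=0, seed=42)
loader = MultipackDistributedDataloader(
    dataset,
    collate_fn=lambda chunks: chunks,  # pass-through so chunks stay inspectable
    seq_max_length=8,
    batch_size=1,
    sampler=sampler,
)
for batch in loader:
    print([chunk["input_ids"].shape for chunk in batch])
print(f"packing efficiency: {loader.efficiency():.2%}")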

src/axolotl/utils/trainer.py (View File)

@@ -1,5 +1,4 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
@@ -7,13 +6,15 @@ import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

import bitsandbytes as bnb
import torch.cuda
import transformers
+from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
+from torch.utils.data import DataLoader, DistributedSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -21,7 +22,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
-from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
@@ -40,6 +41,14 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
class AxolotlTrainer(Trainer):
@@ -77,6 +86,32 @@ class AxolotlTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
)
return super().get_train_dataloader()
    def get_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> Union[DataLoader, MultipackDistributedDataloader]:
        if self.args.sample_packing:
            # fall back to the trainer's eval_dataset when none is passed in
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )
            eval_sampler = self._get_eval_sampler(eval_dataset)
return MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
)
return super().get_eval_dataloader(eval_dataset)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
@@ -108,9 +143,36 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
-    total_num_steps = int(
-        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
-    )
if cfg.sample_packing:
sampler = DistributedSampler(
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed,
)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding="longest",
),
sampler=sampler,
)
total_num_steps = int(
math.ceil(
len(data_loader)
* cfg.micro_batch_size
* cfg.num_epochs
/ cfg.batch_size
)
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
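    # e.g. with illustrative numbers: 1000 packed batches per epoch,
    # micro_batch_size 2, num_epochs 3, and batch_size 8 (i.e. gradient
    # accumulation 4) give ceil(1000 * 2 * 3 / 8) = 750 optimizer steps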
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None
@@ -191,6 +253,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
        # total_num_steps already accounts for cfg.num_epochs; passing it as
        # max_steps is helpful in case we don't actually know the total
        # number of steps up front
        max_steps=total_num_steps,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
@@ -222,6 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
        sample_packing=bool(cfg.sample_packing),
**training_arguments_kwargs,
)
@@ -347,7 +412,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
-        data_collator=DataCollatorForSeq2Seq(
+        data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,