From a6b37bdeb41107ec8b9cc150e3f73cee668e819a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 May 2024 11:51:18 -0400 Subject: [PATCH] revert multipack batch sampler changes (#1672) * revert multipack batch sampler changes * fix default val for drop_last --- src/axolotl/utils/samplers/multipack.py | 245 +++++++++++++++--------- 1 file changed, 159 insertions(+), 86 deletions(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 1d025ca2d..957ca5746 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -1,64 +1,105 @@ +# pylint: skip-file """ Multipack Batch Sampler """ import logging -from concurrent.futures import ProcessPoolExecutor -from multiprocessing import cpu_count +import math +import os +from typing import Any, Iterable, List, Union import numba import numpy as np -from torch.utils.data import BatchSampler +from torch.utils.data import BatchSampler, Sampler LOG = logging.getLogger("axolotl.utils.samplers.multipack") -# First-fit-decreasing bin packing. @numba.njit -def pack_group(items, group_offset, bin_capacity, max_items_per_bin): - idxs = np.argsort(items)[::-1] - sorted_items = items[idxs] - num_bins = len(items) - bins = np.full(num_bins, bin_capacity, dtype=np.int32) - bin_counts = np.zeros(num_bins, dtype=np.int32) - group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32) +def ffd_check(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - for idx, item in enumerate(sorted_items): - global_idx = idxs[idx] + group_offset - - placed = False - for i in range(num_bins): - if bins[i] >= item and bin_counts[i] < max_items_per_bin: - bins[i] -= item - group_packing[i, bin_counts[i]] = global_idx - bin_counts[i] += 1 - placed = True + a = np.sort(a)[::-1] + bins = np.full((n,), c, dtype=a.dtype) + for size in a: + not_found = True + for idx in range(n): + if bins[idx] >= size: + bins[idx] -= size + not_found = False break - if not placed: - raise ValueError( - f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})." - ) + if not_found: + return False - return group_packing + return True -def pack(items, bin_capacity, group_size, max_items_per_bin): - num_items = len(items) - num_processes = max(1, min(num_items // group_size, cpu_count())) - tasks = [ - (items[i : i + group_size], i, bin_capacity, max_items_per_bin) - for i in range(0, num_items, group_size) - ] +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) - packed_bins = [] - with ProcessPoolExecutor(max_workers=num_processes) as executor: - for group_packing in executor.map(pack_group, *zip(*tasks)): - for bin_pack in group_packing: - filtered_pack = bin_pack[bin_pack != -1] - if filtered_pack.size > 0: - packed_bins.append(filtered_pack.tolist()) + indices = np.argsort(a)[::-1] + a = a[indices] - return packed_bins + bins: List[Any] = [] + bins_result: List[Any] = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def allocate( + lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int +): + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + + s = 0 + start_index = 0 + result = [] + + while True: + # binary search [l, r) + left = 1 + right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while right - left > 1: + mid = (left + right) // 2 + if ffd_check(lengths[start_index : start_index + mid], c, n): + left = mid + else: + right = mid + + # use length l + batch = ffd_with_result( + lengths[start_index : start_index + left], c, start_index + ) + assert len(batch) <= n + if len(batch) < n: + break + + start_index += left + s = lengths_cumsum[start_index - 1] + + # add local rank + result.append(batch[rank]) + + return result, s, len(result) * c * n class MultipackBatchSampler(BatchSampler): @@ -68,63 +109,95 @@ class MultipackBatchSampler(BatchSampler): def __init__( self, - sampler, - lengths, - batch_max_len, - batch_size, - group_size=100_000, - bin_size=200, - drop_last=False, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + batch_max_len: int, + lengths: np.ndarray, + packing_efficiency_estimate: float = 1.0, + drop_last: bool = False, + **kwargs, ): - self.sampler = sampler - self.lengths = np.array(lengths, dtype=np.int32) - self.batch_max_len = batch_max_len + super().__init__(sampler, batch_size, drop_last) self.batch_size = batch_size - self.group_size = group_size if group_size is not None else 100_000 - self.bin_size = bin_size if bin_size is not None else 200 - self.drop_last = drop_last + self.batch_max_len = batch_max_len + self.lengths: np.ndarray = lengths + self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 - self._efficiency = None - self._batches = None + assert isinstance(self.lengths, np.ndarray) - def efficiency(self): - if self._efficiency is None: - self._batches = self._pack_batches() - return self._efficiency + self.epoch = 0 - def _pack_batches(self): - # Get possibly shuffled indices from sampler. - sample_idxs = np.arange(len(self.sampler)) - lengths = self.lengths[sample_idxs] + # statistics + self.eff_total_used = 0 + self.eff_total_slots = 0 - pack_idxs = pack( - lengths, - self.batch_max_len, - self.group_size, - self.bin_size, + def set_epoch(self, epoch: int): + self.epoch = epoch + + def generate_batches(self, set_stats=False): + indices = [idx for idx in self.sampler] + + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=0, + c=self.batch_max_len, + n=1, ) - used_tokens = self.lengths.sum() - available_tokens = len(pack_idxs) * self.batch_max_len - self._efficiency = used_tokens / available_tokens - - # Wrap packs into batches. - batch_idxs = [ - pack_idxs[i : i + self.batch_size] - for i in range(0, len(pack_idxs), self.batch_size) + batches = [ + [ + [indices[b_idx] for b_idx in batch] + for batch in batches[i : i + self.batch_size] + ] + for i in range(0, len(batches), self.batch_size) ] - # Drop last batch if needed. - if self.drop_last and len(batch_idxs[-1]) < self.batch_size: - batch_idxs = batch_idxs[:-1] + # statistics + if set_stats: + self.eff_total_used += total_used + self.eff_total_slots += total_slots - return batch_idxs + return batches def __iter__(self): - self._batches = self._pack_batches() - return iter(self._batches) + batches = self.generate_batches(set_stats=True) + return iter(batches) + + def num_batches(self): + batches = self.generate_batches(set_stats=True) + return len(batches) + + def efficiency(self): + return self.eff_total_used / self.eff_total_slots def __len__(self): - if self._batches is None: - self._batches = self._pack_batches() - return len(self._batches) + self.num_batches() + return self._len_est() + + def _len_est(self): + world_size = int(os.getenv("WORLD_SIZE", "1")) + lengths_sum = np.sum(self.lengths) + lengths_sum_per_device = lengths_sum // world_size + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"total_num_tokens per device: {lengths_sum_per_device}" + ) + + # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler + return max( + 0, + ( + world_size + * math.floor( + 0.99 + * lengths_sum_per_device + / self.packing_efficiency_estimate + // (self.batch_max_len * self.batch_size) + ) + - 1 + ), + )