revert multipack batch sampler changes (#1672)

* revert multipack batch sampler changes

* fix default val for drop_last
This commit is contained in:
Wing Lian
2024-05-29 11:51:18 -04:00
committed by GitHub
parent b7520801a3
commit a6b37bdeb4

View File

@@ -1,64 +1,105 @@
# pylint: skip-file
""" """
Multipack Batch Sampler Multipack Batch Sampler
""" """
import logging import logging
from concurrent.futures import ProcessPoolExecutor import math
from multiprocessing import cpu_count import os
from typing import Any, Iterable, List, Union
import numba import numba
import numpy as np import numpy as np
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler, Sampler
LOG = logging.getLogger("axolotl.utils.samplers.multipack") LOG = logging.getLogger("axolotl.utils.samplers.multipack")
# First-fit-decreasing bin packing.
@numba.njit @numba.njit
def pack_group(items, group_offset, bin_capacity, max_items_per_bin): def ffd_check(a: np.ndarray, c: int, n: int):
idxs = np.argsort(items)[::-1] # First-fit-decreasing bin packing
sorted_items = items[idxs] # Check if a[] could fit in n bins with capacity c
num_bins = len(items) # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
bins = np.full(num_bins, bin_capacity, dtype=np.int32)
bin_counts = np.zeros(num_bins, dtype=np.int32)
group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)
for idx, item in enumerate(sorted_items): a = np.sort(a)[::-1]
global_idx = idxs[idx] + group_offset bins = np.full((n,), c, dtype=a.dtype)
for size in a:
placed = False not_found = True
for i in range(num_bins): for idx in range(n):
if bins[i] >= item and bin_counts[i] < max_items_per_bin: if bins[idx] >= size:
bins[i] -= item bins[idx] -= size
group_packing[i, bin_counts[i]] = global_idx not_found = False
bin_counts[i] += 1
placed = True
break break
if not placed: if not_found:
raise ValueError( return False
f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
)
return group_packing return True
def pack(items, bin_capacity, group_size, max_items_per_bin): @numba.njit
num_items = len(items) def ffd_with_result(a: np.ndarray, c: int, start_index: int):
num_processes = max(1, min(num_items // group_size, cpu_count())) # First-fit-decreasing bin packing (with result return)
tasks = [
(items[i : i + group_size], i, bin_capacity, max_items_per_bin)
for i in range(0, num_items, group_size)
]
packed_bins = [] indices = np.argsort(a)[::-1]
with ProcessPoolExecutor(max_workers=num_processes) as executor: a = a[indices]
for group_packing in executor.map(pack_group, *zip(*tasks)):
for bin_pack in group_packing:
filtered_pack = bin_pack[bin_pack != -1]
if filtered_pack.size > 0:
packed_bins.append(filtered_pack.tolist())
return packed_bins bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result
@numba.njit
def allocate(
    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
    """Greedily split ``lengths`` into consecutive groups of ``n`` bins.

    Dynamic batch allocator, similar to Multifit
    (https://en.wikipedia.org/wiki/Multifit_algorithm). The original author
    reports ~99.5% efficiency on the OpenChat training set
    (12 * 2048 ctx len).

    Each round takes a prefix of the not-yet-consumed samples, packs it with
    ``ffd_with_result`` into bins of capacity ``c`` and keeps only the bin
    at position ``rank``. Allocation stops as soon as a round yields fewer
    than ``n`` bins, so the samples of that final round are not returned.

    Args:
        lengths: per-sample lengths, in the order they should be consumed.
        lengths_cumsum: running sum of ``lengths``.
        rank: index of the bin to keep from every round.
        c: capacity of a single bin.
        n: number of bins produced per round.

    Returns:
        A tuple ``(result, s, slots)`` where ``result`` is the list of kept
        bins (each a list of sample indices), ``s`` is the cumulative length
        of all samples consumed by completed rounds, and ``slots`` is
        ``len(result) * c * n``.
    """
    consumed = 0
    offset = 0
    picked = []

    while True:
        # Binary search over the prefix size in [lo, hi). The upper bound is
        # the point where the raw token total would exceed the n * c budget.
        lo = 1
        hi = 1 + np.searchsorted(lengths_cumsum[offset:], consumed + c * n, "right")

        while hi - lo > 1:
            mid = (lo + hi) // 2
            candidate = lengths[offset : offset + mid]
            if ffd_check(candidate, c, n):
                lo = mid
            else:
                hi = mid

        # NOTE: a prefix of size 1 is never passed through ffd_check; it is
        # accepted as the starting lower bound of the search.
        batch = ffd_with_result(lengths[offset : offset + lo], c, offset)
        assert len(batch) <= n
        if len(batch) < n:
            # Not enough samples left to fill all n bins: stop here.
            break

        offset += lo
        consumed = lengths_cumsum[offset - 1]

        # Keep only the bin that belongs to this rank.
        picked.append(batch[rank])

    return picked, consumed, len(picked) * c * n
class MultipackBatchSampler(BatchSampler): class MultipackBatchSampler(BatchSampler):
@@ -68,63 +109,95 @@ class MultipackBatchSampler(BatchSampler):
def __init__( def __init__(
self, self,
sampler, sampler: Union[Sampler[int], Iterable[int]],
lengths, batch_size: int,
batch_max_len, batch_max_len: int,
batch_size, lengths: np.ndarray,
group_size=100_000, packing_efficiency_estimate: float = 1.0,
bin_size=200, drop_last: bool = False,
drop_last=False, **kwargs,
): ):
self.sampler = sampler super().__init__(sampler, batch_size, drop_last)
self.lengths = np.array(lengths, dtype=np.int32)
self.batch_max_len = batch_max_len
self.batch_size = batch_size self.batch_size = batch_size
self.group_size = group_size if group_size is not None else 100_000 self.batch_max_len = batch_max_len
self.bin_size = bin_size if bin_size is not None else 200 self.lengths: np.ndarray = lengths
self.drop_last = drop_last self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self._efficiency = None assert isinstance(self.lengths, np.ndarray)
self._batches = None
def efficiency(self): self.epoch = 0
if self._efficiency is None:
self._batches = self._pack_batches()
return self._efficiency
def _pack_batches(self): # statistics
# Get possibly shuffled indices from sampler. self.eff_total_used = 0
sample_idxs = np.arange(len(self.sampler)) self.eff_total_slots = 0
lengths = self.lengths[sample_idxs]
pack_idxs = pack( def set_epoch(self, epoch: int):
lengths, self.epoch = epoch
self.batch_max_len,
self.group_size, def generate_batches(self, set_stats=False):
self.bin_size, indices = [idx for idx in self.sampler]
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
) )
used_tokens = self.lengths.sum() batches = [
available_tokens = len(pack_idxs) * self.batch_max_len [
self._efficiency = used_tokens / available_tokens [indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
# Wrap packs into batches. ]
batch_idxs = [ for i in range(0, len(batches), self.batch_size)
pack_idxs[i : i + self.batch_size]
for i in range(0, len(pack_idxs), self.batch_size)
] ]
# Drop last batch if needed. # statistics
if self.drop_last and len(batch_idxs[-1]) < self.batch_size: if set_stats:
batch_idxs = batch_idxs[:-1] self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batch_idxs return batches
def __iter__(self): def __iter__(self):
self._batches = self._pack_batches() batches = self.generate_batches(set_stats=True)
return iter(self._batches) return iter(batches)
def num_batches(self):
    """Regenerate the packed batches and return how many there are.

    The call is made with ``set_stats=True``, so the efficiency counters
    are updated as a side effect of counting.
    """
    return len(self.generate_batches(set_stats=True))
def efficiency(self):
    """Return the fraction of allocated token slots that were actually used.

    The ratio is ``eff_total_used / eff_total_slots``, both accumulated by
    ``generate_batches(set_stats=True)``.

    Returns:
        The used/available ratio, or ``0.0`` when no slots have been
        recorded yet (for example before the sampler was iterated, or when
        packing produced no batches) instead of raising
        ``ZeroDivisionError``.
    """
    if not self.eff_total_slots:
        # Nothing has been packed so far; there is no meaningful ratio.
        return 0.0
    return self.eff_total_used / self.eff_total_slots
def __len__(self): def __len__(self):
if self._batches is None: self.num_batches()
self._batches = self._pack_batches() return self._len_est()
return len(self._batches)
def _len_est(self):
    """Estimate the number of batches from token totals alone.

    The world size is read from the ``WORLD_SIZE`` environment variable
    (``"1"`` when unset). The total of ``self.lengths`` is divided evenly
    across devices, scaled by ``self.packing_efficiency_estimate`` and
    divided by the token capacity of one batch
    (``batch_max_len * batch_size``).

    Returns:
        A non-negative integer estimate of the sampler length.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    lengths_sum_per_device = np.sum(self.lengths) // world_size
    LOG.info(
        f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
        f"total_num_tokens per device: {lengths_sum_per_device}"
    )

    # Shave off 1% (and one more batch below) to absorb the variance in
    # packing results between differently shuffled samplers.
    discounted_tokens = (
        0.99 * lengths_sum_per_device / self.packing_efficiency_estimate
    )
    tokens_per_batch = self.batch_max_len * self.batch_size
    batches_per_device = math.floor(discounted_tokens // tokens_per_batch)

    estimate = world_size * batches_per_device - 1
    return max(0, estimate)