parallel bin packing

fix error with lambda and pickling

make sure things are in float instead of np.float
This commit is contained in:
Wing Lian
2025-05-04 18:12:09 -04:00
parent 03508c6816
commit 5b2bd75aba

View File

@@ -1,9 +1,12 @@
# pylint: skip-file
"""
Multipack Batch Sampler
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, List, Union
import numba
@@ -13,17 +16,24 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
# First-fit-decreasing bin packing algorithm
# Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity
# Returns True if all sequences can be packed, False otherwise
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
"""
First-fit-decreasing bin packing algorithm check
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
@@ -46,98 +56,104 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
@numba.njit
def ffd_with_result(sequence_lengths: np.ndarray, bin_capacity: int, start_index: int):
# First-fit-decreasing bin packing that returns the actual bin assignments
# Returns a list of bins, where each bin contains indices of sequences assigned to it
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
safe_mode: bool = False,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
# Get sorting indices and sort sequence lengths in descending order
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
safe_mode: If True, use a more conservative packing approach
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sequence_lengths = sequence_lengths[indices]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
# Place each sequence in the first bin it fits
for seq_id, size in enumerate(sequence_lengths):
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx in range(len(bins_remaining_space)):
for bin_idx, _ in enumerate(bins_remaining_space):
if bins_remaining_space[bin_idx] >= size:
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index)
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
break
# If no existing bin could fit this sequence, create a new bin
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([indices[seq_id] + start_index])
bins_assigned_sequences.append([global_idx])
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
@numba.njit
def allocate(
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, safe_mode = args
return pack_group(group_lengths, start_idx, bin_capacity, max_bins, safe_mode)
def pack_parallel(
sequence_lengths: np.ndarray,
lengths_cumsum: np.ndarray,
rank: int,
bin_capacity: int,
num_ranks: int,
group_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
):
# Dynamic batch allocator, similar to Multifit algorithm
# Efficiently packs sequences into fixed-capacity bins for distributed training
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
"""
Pack sequences into bins using parallel processing
total_processed_tokens = 0
start_index = 0
rank_batches = [] # Batches assigned to the current rank
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
group_size: Number of sequences to process in each group
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
while True:
# Binary search to find maximum number of sequences that can be packed into num_ranks bins
# [left, right) defines the search range
left = 1
right = 1 + np.searchsorted(
lengths_cumsum[start_index:],
total_processed_tokens + bin_capacity * num_ranks,
"right",
)
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
while right - left > 1:
mid = (left + right) // 2
if ffd_check(
sequence_lengths[start_index : start_index + mid],
bin_capacity,
num_ranks,
):
left = mid
else:
right = mid
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, safe_mode))
# Pack the identified sequences into bins
all_rank_batches = ffd_with_result(
sequence_lengths[start_index : start_index + left],
bin_capacity,
start_index,
)
assert len(all_rank_batches) <= num_ranks
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
# If we couldn't fill all ranks, we're done
if len(all_rank_batches) < num_ranks:
break
# Update indices and processed token count
start_index += left
total_processed_tokens = lengths_cumsum[start_index - 1]
# Add the batch for the current rank
rank_batches.append(all_rank_batches[rank])
# Return batches for this rank, total tokens used, and total token slots available
return (
rank_batches,
total_processed_tokens,
len(rank_batches) * bin_capacity * num_ranks,
)
return all_bins
@numba.njit
@@ -145,25 +161,25 @@ def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
"""
Sequential allocator that preserves example order (no sorting by length)
Sequential allocator that preserves example order
Arguments:
- sequence_lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- bin_capacity: The capacity of each bin (maximum sequence length)
- num_ranks: Number of ranks (processes/GPUs)
Parameters:
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
Returns:
- rank_batches: List of batches for the current rank
- total_tokens_used: Number of actual example tokens
- total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
rank_batches = []
total_tokens_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint for empty list of integers
current_bin = []
remaining_capacity = bin_capacity
# Process each sequence in order
@@ -194,12 +210,12 @@ def allocate_sequentially(
class MultipackBatchSampler(BatchSampler):
"""
Batch sampler class for efficient packing of variable-length sequences.
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both length-optimized packing (using FFD algorithm) and
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
@@ -212,29 +228,38 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing instead of FFD
**kwargs,
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.lengths = np.array(lengths, dtype=np.int32)
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# Efficiency statistics tracking
self.eff_total_used = 0 # Total tokens used
self.eff_total_slots = 0 # Total token slots available
self.total_tokens_used = 0
self.total_token_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
@@ -243,6 +268,7 @@ class MultipackBatchSampler(BatchSampler):
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
"""
@@ -255,44 +281,62 @@ class MultipackBatchSampler(BatchSampler):
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
# Get indices from the sampler
indices = [idx for idx in self.sampler]
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# Get lengths of the selected sequences
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
# Pack sequences into bins using either sequential or FFD allocation
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths=lengths,
lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
else:
bins, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
num_ranks=1,
group_size=self.group_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
[
[indices[b_idx] for b_idx in bin_indices]
for bin_indices in bins[i : i + self.batch_size]
]
for i in range(0, len(bins), self.batch_size)
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
@@ -308,17 +352,17 @@ class MultipackBatchSampler(BatchSampler):
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
"""Calculate the number of batches for this rank"""
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
return self.eff_total_used / self.eff_total_slots
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self):
"""
@@ -329,11 +373,12 @@ class MultipackBatchSampler(BatchSampler):
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
return math.floor(0.997 * max(estimates))
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
@@ -364,11 +409,15 @@ class MultipackBatchSampler(BatchSampler):
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if not self.len_across_ranks:
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks