parallel bin packing

fix error with lambda and pickling

make sure things are in float instead of np.float
This commit is contained in:
Wing Lian
2025-05-04 18:12:09 -04:00
parent 03508c6816
commit 5b2bd75aba

View File

@@ -1,9 +1,12 @@
# pylint: skip-file
""" """
Multipack Batch Sampler Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
""" """
import logging import logging
import math import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, List, Union from typing import Iterable, List, Union
import numba import numba
@@ -13,17 +16,24 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO) LOG.setLevel(logging.INFO)
@numba.njit @numba.njit
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
# First-fit-decreasing bin packing algorithm """
# Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity First-fit-decreasing bin packing algorithm check
# Returns True if all sequences can be packed, False otherwise
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing # Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1] sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity # Initialize all bins with full capacity
@@ -46,98 +56,104 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
@numba.njit @numba.njit
def ffd_with_result(sequence_lengths: np.ndarray, bin_capacity: int, start_index: int): def pack_group(
# First-fit-decreasing bin packing that returns the actual bin assignments sequence_lengths: np.ndarray,
# Returns a list of bins, where each bin contains indices of sequences assigned to it group_offset: int,
bin_capacity: int,
max_bins: int,
safe_mode: bool = False,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
# Get sorting indices and sort sequence lengths in descending order Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
safe_mode: If True, use a more conservative packing approach
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1] indices = np.argsort(sequence_lengths)[::-1]
sequence_lengths = sequence_lengths[indices] sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
# Place each sequence in the first bin it fits for seq_id, size in enumerate(sorted_lengths):
for seq_id, size in enumerate(sequence_lengths): global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True add_new_bin = True
for bin_idx in range(len(bins_remaining_space)): for bin_idx, _ in enumerate(bins_remaining_space):
if bins_remaining_space[bin_idx] >= size: if bins_remaining_space[bin_idx] >= size:
bins_remaining_space[bin_idx] -= size bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index) bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False add_new_bin = False
break break
# If no existing bin could fit this sequence, create a new bin # Create a new bin if needed and if we haven't reached the limit
if add_new_bin: if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size) bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([indices[seq_id] + start_index]) bins_assigned_sequences.append([global_idx])
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences return bins_assigned_sequences
@numba.njit # Define a standalone function for multiprocessing
def allocate( def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, safe_mode = args
return pack_group(group_lengths, start_idx, bin_capacity, max_bins, safe_mode)
def pack_parallel(
sequence_lengths: np.ndarray, sequence_lengths: np.ndarray,
lengths_cumsum: np.ndarray,
rank: int,
bin_capacity: int, bin_capacity: int,
num_ranks: int, group_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
): ):
# Dynamic batch allocator, similar to Multifit algorithm """
# Efficiently packs sequences into fixed-capacity bins for distributed training Pack sequences into bins using parallel processing
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
total_processed_tokens = 0 Args:
start_index = 0 sequence_lengths: Array of sequence lengths
rank_batches = [] # Batches assigned to the current rank bin_capacity: Maximum capacity of each bin
group_size: Number of sequences to process in each group
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
while True: Returns:
# Binary search to find maximum number of sequences that can be packed into num_ranks bins List of bins, where each bin contains indices of sequences assigned to it
# [left, right) defines the search range """
left = 1 num_items = len(sequence_lengths)
right = 1 + np.searchsorted( if num_processes is None:
lengths_cumsum[start_index:], num_processes = max(1, min(num_items // group_size, cpu_count()))
total_processed_tokens + bin_capacity * num_ranks,
"right",
)
while right - left > 1: # Create tasks for parallel processing
mid = (left + right) // 2 tasks = []
if ffd_check( for i in range(0, num_items, group_size):
sequence_lengths[start_index : start_index + mid], group_lengths = sequence_lengths[i : i + group_size]
bin_capacity, max_bins = len(group_lengths) # Allow as many bins as items in the group
num_ranks, tasks.append((group_lengths, i, bin_capacity, max_bins, safe_mode))
):
left = mid
else:
right = mid
# Pack the identified sequences into bins # Process groups in parallel
all_rank_batches = ffd_with_result( all_bins = []
sequence_lengths[start_index : start_index + left], with ProcessPoolExecutor(max_workers=num_processes) as executor:
bin_capacity, for group_bins in executor.map(_process_group, tasks):
start_index, all_bins.extend(group_bins)
)
assert len(all_rank_batches) <= num_ranks
# If we couldn't fill all ranks, we're done return all_bins
if len(all_rank_batches) < num_ranks:
break
# Update indices and processed token count
start_index += left
total_processed_tokens = lengths_cumsum[start_index - 1]
# Add the batch for the current rank
rank_batches.append(all_rank_batches[rank])
# Return batches for this rank, total tokens used, and total token slots available
return (
rank_batches,
total_processed_tokens,
len(rank_batches) * bin_capacity * num_ranks,
)
@numba.njit @numba.njit
@@ -145,25 +161,25 @@ def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
): ):
""" """
Sequential allocator that preserves example order (no sorting by length) Sequential allocator that preserves example order
Arguments: Parameters:
- sequence_lengths: The lengths of all examples sequence_lengths: The lengths of all examples
- rank: The current rank (for distributed training) rank: The current rank (for distributed training)
- bin_capacity: The capacity of each bin (maximum sequence length) bin_capacity: The capacity of each bin (maximum sequence length)
- num_ranks: Number of ranks (processes/GPUs) num_ranks: Number of ranks (processes/GPUs)
Returns: Returns:
- rank_batches: List of batches for the current rank rank_batches: List of batches for the current rank
- total_tokens_used: Number of actual example tokens total_tokens_used: Number of actual example tokens
- total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
""" """
rank_batches = [] rank_batches = []
total_tokens_used = 0 total_tokens_used = 0
# First, do sequential packing into bins # First, do sequential packing into bins
all_bins = [] all_bins = []
current_bin = [0 for i in range(0)] # numba hint for empty list of integers current_bin = []
remaining_capacity = bin_capacity remaining_capacity = bin_capacity
# Process each sequence in order # Process each sequence in order
@@ -194,12 +210,12 @@ def allocate_sequentially(
class MultipackBatchSampler(BatchSampler): class MultipackBatchSampler(BatchSampler):
""" """
Batch sampler class for efficient packing of variable-length sequences. Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding. GPU memory utilization and training throughput by reducing padding.
It supports both length-optimized packing (using FFD algorithm) and It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order). sequential packing (preserving original sequence order).
""" """
@@ -212,29 +228,38 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing instead of FFD sequential: bool = False, # Whether to use sequential packing
**kwargs, group_size: int = 100_000, # Size of groups for parallel packing
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size self.batch_size = batch_size
self.batch_max_len = batch_max_len self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths self.lengths = np.array(lengths, dtype=np.int32)
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential self.sequential = sequential
self.group_size = group_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
self.epoch = 0 self.epoch = 0
# Efficiency statistics tracking # Efficiency statistics tracking
self.eff_total_used = 0 # Total tokens used self.total_tokens_used = 0
self.eff_total_slots = 0 # Total token slots available self.total_token_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length # The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples self.num_count_samples = num_count_samples
# Minimum packed dataset length across all ranks (determined by gather/broadcast) # Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler): if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn( LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?" "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
@@ -243,6 +268,7 @@ class MultipackBatchSampler(BatchSampler):
def set_epoch(self, epoch: int): def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs""" """Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
""" """
@@ -255,44 +281,62 @@ class MultipackBatchSampler(BatchSampler):
List of batches, where each batch contains multiple bins, List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices and each bin contains multiple sequence indices
""" """
if self._batches is not None:
return self._batches
# Get indices from the sampler # Get indices from the sampler
indices = [idx for idx in self.sampler] indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# Get lengths of the selected sequences # Get lengths of the selected sequences
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
# Pack sequences into bins using either sequential or FFD allocation # Pack sequences into bins using either sequential or parallel packing
if self.sequential: if self.sequential:
bins, total_used, total_slots = allocate_sequentially( bins, total_used, total_slots = allocate_sequentially(
lengths=lengths, lengths,
rank=0, rank=0,
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
num_ranks=1, num_ranks=1,
) )
else: else:
bins, total_used, total_slots = allocate( # Use parallel packing
lengths=lengths, all_bins = pack_parallel(
lengths_cumsum=lengths_cumsum, lengths,
rank=0,
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
num_ranks=1, group_size=self.group_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
) )
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins) # Group bins into batches (each batch contains batch_size bins)
batches = [ batches = [
[ bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
[indices[b_idx] for b_idx in bin_indices]
for bin_indices in bins[i : i + self.batch_size]
]
for i in range(0, len(bins), self.batch_size)
] ]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested # Update statistics if requested
if set_stats: if set_stats:
self.eff_total_used += total_used self.total_tokens_used += total_used
self.eff_total_slots += total_slots self.total_token_slots += total_slots
self._batches = batches
return batches return batches
def __iter__(self): def __iter__(self):
@@ -308,17 +352,17 @@ class MultipackBatchSampler(BatchSampler):
batches = batches[: self.len_across_ranks] batches = batches[: self.len_across_ranks]
return iter(batches) return iter(batches)
def num_batches(self):
"""Calculate the number of batches for this rank"""
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self): def efficiency(self):
""" """
Calculate the packing efficiency (ratio of tokens used to total token slots) Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space Higher is better - 1.0 would mean perfect packing with no wasted space
""" """
return self.eff_total_used / self.eff_total_slots if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self): def gather_efficiency(self):
""" """
@@ -329,11 +373,12 @@ class MultipackBatchSampler(BatchSampler):
def calc_sample_packing_eff_est(estimates: List[float]): def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate # Use 99.7% of max observed efficiency as a safe estimate
return math.floor(0.997 * max(estimates)) max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
# Gather efficiency from all ranks and apply the calculation function # Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast( sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est, calc_sample_packing_eff_est,
) )
@@ -364,11 +409,15 @@ class MultipackBatchSampler(BatchSampler):
This is calculated as the minimum number of batches available on any rank This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training to ensure balanced distributed training
""" """
if not self.len_across_ranks: if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate # Sample multiple times to get stable estimate
len_batches = min( len_batches = min( # pylint: disable=consider-using-generator
[self.num_batches() for _ in range(self.num_count_samples)] [len(self._batches) for _ in range(self.num_count_samples)]
) )
# Gather minimum across all ranks # Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches) self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks return self.len_across_ranks