From 5b2bd75aba378739ee56a6bb6ac679a085a0be05 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 4 May 2025 18:12:09 -0400 Subject: [PATCH] parallel bin packing fix error with lambda and pickling make sure things are in float instead of np.float --- src/axolotl/utils/samplers/multipack.py | 283 ++++++++++++++---------- 1 file changed, 166 insertions(+), 117 deletions(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index e419486e2..1abe594e3 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -1,9 +1,12 @@ -# pylint: skip-file """ -Multipack Batch Sampler +Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences +into fixed-capacity batches to optimize memory usage and training throughput. """ + import logging import math +from concurrent.futures import ProcessPoolExecutor +from multiprocessing import cpu_count from typing import Iterable, List, Union import numba @@ -13,17 +16,24 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler from axolotl.utils.distributed import reduce_and_broadcast LOG = logging.getLogger(__name__) - LOG.setLevel(logging.INFO) @numba.njit def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): - # First-fit-decreasing bin packing algorithm - # Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity - # Returns True if all sequences can be packed, False otherwise - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + """ + First-fit-decreasing bin packing algorithm check + Checks if sequences with the given lengths could fit in the specified number of bins + + Args: + sequence_lengths: Array of sequence lengths + bin_capacity: Maximum capacity of each bin + num_bins: Number of bins available + + Returns: + True if all sequences can be packed, False otherwise + """ # Sort sequence lengths in descending order for optimal packing sequence_lengths = np.sort(sequence_lengths)[::-1] # Initialize all bins with full capacity @@ -46,98 +56,104 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): @numba.njit -def ffd_with_result(sequence_lengths: np.ndarray, bin_capacity: int, start_index: int): - # First-fit-decreasing bin packing that returns the actual bin assignments - # Returns a list of bins, where each bin contains indices of sequences assigned to it +def pack_group( + sequence_lengths: np.ndarray, + group_offset: int, + bin_capacity: int, + max_bins: int, + safe_mode: bool = False, +): + """ + Pack a group of sequences into bins using First-Fit Decreasing algorithm - # Get sorting indices and sort sequence lengths in descending order + Args: + sequence_lengths: Array of sequence lengths + group_offset: Offset to apply to indices when returning results + bin_capacity: Maximum capacity of each bin + max_bins: Maximum number of bins to use + safe_mode: If True, use a more conservative packing approach + + Returns: + List of bins, where each bin contains indices of sequences assigned to it + """ + # Get sorting indices and sort lengths in descending order indices = np.argsort(sequence_lengths)[::-1] - sequence_lengths = sequence_lengths[indices] + sorted_lengths = sequence_lengths[indices] bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin - # Place each sequence in the first bin it fits - for seq_id, size in enumerate(sequence_lengths): + for seq_id, size in enumerate(sorted_lengths): + global_idx = indices[seq_id] + group_offset + + # Try to place sequence in existing bins add_new_bin = True - for bin_idx in range(len(bins_remaining_space)): + for bin_idx, _ in enumerate(bins_remaining_space): if bins_remaining_space[bin_idx] >= size: bins_remaining_space[bin_idx] -= size - bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index) + bins_assigned_sequences[bin_idx].append(global_idx) add_new_bin = False break - # If no existing bin could fit this sequence, create a new bin + # Create a new bin if needed and if we haven't reached the limit if add_new_bin: + if len(bins_remaining_space) >= max_bins and safe_mode: + # In safe mode, skip items that would exceed max_bins + continue bins_remaining_space.append(bin_capacity - size) - bins_assigned_sequences.append([indices[seq_id] + start_index]) + bins_assigned_sequences.append([global_idx]) + + # Safety check to avoid infinite bins + if len(bins_remaining_space) > len(sequence_lengths): + break return bins_assigned_sequences -@numba.njit -def allocate( +# Define a standalone function for multiprocessing +def _process_group(args): + group_lengths, start_idx, bin_capacity, max_bins, safe_mode = args + return pack_group(group_lengths, start_idx, bin_capacity, max_bins, safe_mode) + + +def pack_parallel( sequence_lengths: np.ndarray, - lengths_cumsum: np.ndarray, - rank: int, bin_capacity: int, - num_ranks: int, + group_size: int, + num_processes: int | None = None, + safe_mode: bool = True, ): - # Dynamic batch allocator, similar to Multifit algorithm - # Efficiently packs sequences into fixed-capacity bins for distributed training - # https://en.wikipedia.org/wiki/Multifit_algorithm - # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + """ + Pack sequences into bins using parallel processing - total_processed_tokens = 0 - start_index = 0 - rank_batches = [] # Batches assigned to the current rank + Args: + sequence_lengths: Array of sequence lengths + bin_capacity: Maximum capacity of each bin + group_size: Number of sequences to process in each group + num_processes: Number of parallel processes to use + safe_mode: If True, use a more conservative packing approach - while True: - # Binary search to find maximum number of sequences that can be packed into num_ranks bins - # [left, right) defines the search range - left = 1 - right = 1 + np.searchsorted( - lengths_cumsum[start_index:], - total_processed_tokens + bin_capacity * num_ranks, - "right", - ) + Returns: + List of bins, where each bin contains indices of sequences assigned to it + """ + num_items = len(sequence_lengths) + if num_processes is None: + num_processes = max(1, min(num_items // group_size, cpu_count())) - while right - left > 1: - mid = (left + right) // 2 - if ffd_check( - sequence_lengths[start_index : start_index + mid], - bin_capacity, - num_ranks, - ): - left = mid - else: - right = mid + # Create tasks for parallel processing + tasks = [] + for i in range(0, num_items, group_size): + group_lengths = sequence_lengths[i : i + group_size] + max_bins = len(group_lengths) # Allow as many bins as items in the group + tasks.append((group_lengths, i, bin_capacity, max_bins, safe_mode)) - # Pack the identified sequences into bins - all_rank_batches = ffd_with_result( - sequence_lengths[start_index : start_index + left], - bin_capacity, - start_index, - ) - assert len(all_rank_batches) <= num_ranks + # Process groups in parallel + all_bins = [] + with ProcessPoolExecutor(max_workers=num_processes) as executor: + for group_bins in executor.map(_process_group, tasks): + all_bins.extend(group_bins) - # If we couldn't fill all ranks, we're done - if len(all_rank_batches) < num_ranks: - break - - # Update indices and processed token count - start_index += left - total_processed_tokens = lengths_cumsum[start_index - 1] - - # Add the batch for the current rank - rank_batches.append(all_rank_batches[rank]) - - # Return batches for this rank, total tokens used, and total token slots available - return ( - rank_batches, - total_processed_tokens, - len(rank_batches) * bin_capacity * num_ranks, - ) + return all_bins @numba.njit @@ -145,25 +161,25 @@ def allocate_sequentially( sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int ): """ - Sequential allocator that preserves example order (no sorting by length) + Sequential allocator that preserves example order - Arguments: - - sequence_lengths: The lengths of all examples - - rank: The current rank (for distributed training) - - bin_capacity: The capacity of each bin (maximum sequence length) - - num_ranks: Number of ranks (processes/GPUs) + Parameters: + sequence_lengths: The lengths of all examples + rank: The current rank (for distributed training) + bin_capacity: The capacity of each bin (maximum sequence length) + num_ranks: Number of ranks (processes/GPUs) Returns: - - rank_batches: List of batches for the current rank - - total_tokens_used: Number of actual example tokens - - total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) + rank_batches: List of batches for the current rank + total_tokens_used: Number of actual example tokens + total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) """ rank_batches = [] total_tokens_used = 0 # First, do sequential packing into bins all_bins = [] - current_bin = [0 for i in range(0)] # numba hint for empty list of integers + current_bin = [] remaining_capacity = bin_capacity # Process each sequence in order @@ -194,12 +210,12 @@ def allocate_sequentially( class MultipackBatchSampler(BatchSampler): """ - Batch sampler class for efficient packing of variable-length sequences. + Batch sampler class for efficient packing of variable-length sequences This sampler packs sequences into fixed-capacity bins (batches) to maximize GPU memory utilization and training throughput by reducing padding. - It supports both length-optimized packing (using FFD algorithm) and + It supports both parallel packing (using FFD algorithm) and sequential packing (preserving original sequence order). """ @@ -212,29 +228,38 @@ class MultipackBatchSampler(BatchSampler): packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate drop_last: bool = False, # Whether to drop incomplete batches num_count_samples: int = 16, # Number of samples to estimate batch count - sequential: bool = False, # Whether to use sequential packing instead of FFD - **kwargs, + sequential: bool = False, # Whether to use sequential packing + group_size: int = 100_000, # Size of groups for parallel packing + num_processes: int | None = None, # Number of processes for parallel packing + safe_mode: bool = True, # Conservative packing to prevent training instability + **kwargs, # pylint: disable=unused-argument ): super().__init__(sampler, batch_size, drop_last) self.batch_size = batch_size self.batch_max_len = batch_max_len - self.lengths: np.ndarray = lengths + self.lengths = np.array(lengths, dtype=np.int32) self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.sequential = sequential + self.group_size = group_size + self.num_processes = num_processes + self.safe_mode = safe_mode assert isinstance(self.lengths, np.ndarray) self.epoch = 0 # Efficiency statistics tracking - self.eff_total_used = 0 # Total tokens used - self.eff_total_slots = 0 # Total token slots available + self.total_tokens_used = 0 + self.total_token_slots = 0 # The number of times to calculate batches to determine minimum packed dataset length self.num_count_samples = num_count_samples # Minimum packed dataset length across all ranks (determined by gather/broadcast) self.len_across_ranks = None + # Cache for batches + self._batches = None + if self.sequential and not isinstance(sampler, SequentialSampler): LOG.warn( "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?" @@ -243,6 +268,7 @@ class MultipackBatchSampler(BatchSampler): def set_epoch(self, epoch: int): """Set the epoch number, used for reproducible shuffling across epochs""" self.epoch = epoch + self._batches = None # Invalidate batch cache def generate_batches(self, set_stats=False): """ @@ -255,44 +281,62 @@ class MultipackBatchSampler(BatchSampler): List of batches, where each batch contains multiple bins, and each bin contains multiple sequence indices """ + if self._batches is not None: + return self._batches + # Get indices from the sampler - indices = [idx for idx in self.sampler] + indices = [ # pylint: disable=unnecessary-comprehension + idx for idx in self.sampler + ] # Get lengths of the selected sequences lengths = self.lengths[indices] - lengths_cumsum = np.cumsum(lengths) - # Pack sequences into bins using either sequential or FFD allocation + # Pack sequences into bins using either sequential or parallel packing if self.sequential: bins, total_used, total_slots = allocate_sequentially( - lengths=lengths, + lengths, rank=0, bin_capacity=self.batch_max_len, num_ranks=1, ) else: - bins, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=0, + # Use parallel packing + all_bins = pack_parallel( + lengths, bin_capacity=self.batch_max_len, - num_ranks=1, + group_size=self.group_size, + num_processes=self.num_processes, + safe_mode=self.safe_mode, ) + # Map bin indices back to original indices + bins = [ + [indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins + ] + + # Calculate efficiency statistics + total_used = lengths.sum() + total_slots = len(all_bins) * self.batch_max_len + # Group bins into batches (each batch contains batch_size bins) batches = [ - [ - [indices[b_idx] for b_idx in bin_indices] - for bin_indices in bins[i : i + self.batch_size] - ] - for i in range(0, len(bins), self.batch_size) + bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size) ] + # Drop last batch if requested and it's incomplete + if self.drop_last and len(batches[-1]) < self.batch_size: + batches = batches[:-1] + # Adjust total_slots if we dropped a batch + if not self.sequential: + total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len + # Update statistics if requested if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots + self.total_tokens_used += total_used + self.total_token_slots += total_slots + self._batches = batches return batches def __iter__(self): @@ -308,17 +352,17 @@ class MultipackBatchSampler(BatchSampler): batches = batches[: self.len_across_ranks] return iter(batches) - def num_batches(self): - """Calculate the number of batches for this rank""" - batches = self.generate_batches(set_stats=True) - return len(batches) - def efficiency(self): """ Calculate the packing efficiency (ratio of tokens used to total token slots) Higher is better - 1.0 would mean perfect packing with no wasted space """ - return self.eff_total_used / self.eff_total_slots + if self.total_token_slots == 0: + self.generate_batches(set_stats=True) + if self.total_token_slots == 0: + return 0.0 + # Return a Python float instead of potentially a numpy float + return float(self.total_tokens_used / self.total_token_slots) def gather_efficiency(self): """ @@ -329,11 +373,12 @@ class MultipackBatchSampler(BatchSampler): def calc_sample_packing_eff_est(estimates: List[float]): LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") # Use 99.7% of max observed efficiency as a safe estimate - return math.floor(0.997 * max(estimates)) + max_eff = max(float(eff) for eff in estimates) + return math.floor(0.997 * max_eff) # Gather efficiency from all ranks and apply the calculation function sample_packing_actual_eff_all = reduce_and_broadcast( - lambda: self.efficiency(), # pylint: disable=unnecessary-lambda + lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda calc_sample_packing_eff_est, ) @@ -364,11 +409,15 @@ class MultipackBatchSampler(BatchSampler): This is calculated as the minimum number of batches available on any rank to ensure balanced distributed training """ - if not self.len_across_ranks: + if self._batches is None: + self._batches = self.generate_batches(set_stats=True) + + if self.len_across_ranks is None: # Sample multiple times to get stable estimate - len_batches = min( - [self.num_batches() for _ in range(self.num_count_samples)] + len_batches = min( # pylint: disable=consider-using-generator + [len(self._batches) for _ in range(self.num_count_samples)] ) # Gather minimum across all ranks self.len_across_ranks = self.gather_len_batches(len_batches) + return self.len_across_ranks