From 03508c6816375801b6da89ce55882f8cf85ccc50 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 4 May 2025 18:02:17 -0400 Subject: [PATCH] improve readability of multipack sampler --- src/axolotl/utils/samplers/multipack.py | 267 ++++++++++++++++-------- 1 file changed, 181 insertions(+), 86 deletions(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 0d8806d8b..e419486e2 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -4,7 +4,7 @@ Multipack Batch Sampler """ import logging import math -from typing import Any, Iterable, List, Union +from typing import Iterable, List, Union import numba import numpy as np @@ -18,21 +18,27 @@ LOG.setLevel(logging.INFO) @numba.njit -def ffd_check(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c +def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): + # First-fit-decreasing bin packing algorithm + # Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity + # Returns True if all sequences can be packed, False otherwise # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - a = np.sort(a)[::-1] - bins = np.full((n,), c, dtype=a.dtype) - for size in a: + # Sort sequence lengths in descending order for optimal packing + sequence_lengths = np.sort(sequence_lengths)[::-1] + # Initialize all bins with full capacity + bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype) + + # Try to place each sequence in the first bin it fits + for size in sequence_lengths: not_found = True - for idx in range(n): + for idx in range(num_bins): if bins[idx] >= size: bins[idx] -= size not_found = False break + # If no bin could fit this sequence, packing failed if not_found: return False @@ -40,133 +46,173 @@ def ffd_check(a: np.ndarray, c: int, n: int): @numba.njit -def 
ffd_with_result(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) +def ffd_with_result(sequence_lengths: np.ndarray, bin_capacity: int, start_index: int): + # First-fit-decreasing bin packing that returns the actual bin assignments + # Returns a list of bins, where each bin contains indices of sequences assigned to it - indices = np.argsort(a)[::-1] - a = a[indices] + # Get sorting indices and sort sequence lengths in descending order + indices = np.argsort(sequence_lengths)[::-1] + sequence_lengths = sequence_lengths[indices] - bins: List[Any] = [] - bins_result: List[Any] = [] - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins)): - if bins[idx] >= size: - bins[idx] -= size - bins_result[idx].append(indices[a_id] + start_index) - add_new = False + bins_remaining_space: list = [] # Tracks remaining capacity in each bin + bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin + + # Place each sequence in the first bin it fits + for seq_id, size in enumerate(sequence_lengths): + add_new_bin = True + for bin_idx in range(len(bins_remaining_space)): + if bins_remaining_space[bin_idx] >= size: + bins_remaining_space[bin_idx] -= size + bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index) + add_new_bin = False break - if add_new: - bins.append(c - size) - bins_result.append([indices[a_id] + start_index]) + # If no existing bin could fit this sequence, create a new bin + if add_new_bin: + bins_remaining_space.append(bin_capacity - size) + bins_assigned_sequences.append([indices[seq_id] + start_index]) - return bins_result + return bins_assigned_sequences @numba.njit def allocate( - lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int + sequence_lengths: np.ndarray, + lengths_cumsum: np.ndarray, + rank: int, + bin_capacity: int, + num_ranks: int, ): - # Dynamic batch allocator, similar to Multifit + # Dynamic batch allocator, 
similar to Multifit algorithm + # Efficiently packs sequences into fixed-capacity bins for distributed training # https://en.wikipedia.org/wiki/Multifit_algorithm # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) - s = 0 + total_processed_tokens = 0 start_index = 0 - result = [] + rank_batches = [] # Batches assigned to the current rank while True: - # binary search [l, r) + # Binary search to find maximum number of sequences that can be packed into num_ranks bins + # [left, right) defines the search range left = 1 - right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + right = 1 + np.searchsorted( + lengths_cumsum[start_index:], + total_processed_tokens + bin_capacity * num_ranks, + "right", + ) while right - left > 1: mid = (left + right) // 2 - if ffd_check(lengths[start_index : start_index + mid], c, n): + if ffd_check( + sequence_lengths[start_index : start_index + mid], + bin_capacity, + num_ranks, + ): left = mid else: right = mid - # use length l - batch = ffd_with_result( - lengths[start_index : start_index + left], c, start_index + # Pack the identified sequences into bins + all_rank_batches = ffd_with_result( + sequence_lengths[start_index : start_index + left], + bin_capacity, + start_index, ) - assert len(batch) <= n - if len(batch) < n: + assert len(all_rank_batches) <= num_ranks + + # If we couldn't fill all ranks, we're done + if len(all_rank_batches) < num_ranks: break + # Update indices and processed token count start_index += left - s = lengths_cumsum[start_index - 1] + total_processed_tokens = lengths_cumsum[start_index - 1] - # add local rank - result.append(batch[rank]) + # Add the batch for the current rank + rank_batches.append(all_rank_batches[rank]) - return result, s, len(result) * c * n + # Return batches for this rank, total tokens used, and total token slots available + return ( + rank_batches, + total_processed_tokens, + len(rank_batches) * bin_capacity * num_ranks, + ) @numba.njit -def 
allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int): +def allocate_sequentially( + sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int +): """ - Sequential allocator that preserves example order + Sequential allocator that preserves example order (no sorting by length) - Parameters: - - lengths: The lengths of all examples + Arguments: + - sequence_lengths: The lengths of all examples - rank: The current rank (for distributed training) - - c: The capacity of each bin (maximum sequence length) - - n: Number of ranks + - bin_capacity: The capacity of each bin (maximum sequence length) + - num_ranks: Number of ranks (processes/GPUs) Returns: - - result: List of batches for the current rank - - total_used: Number of actual example tokens - - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) + - rank_batches: List of batches for the current rank + - total_tokens_used: Number of actual example tokens + - total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) """ - result = [] - total_used = 0 + rank_batches = [] + total_tokens_used = 0 # First, do sequential packing into bins all_bins = [] - current_bin = [0 for i in range(0)] # numba hint - remaining_capacity = c + current_bin = [0 for i in range(0)] # numba hint for empty list of integers + remaining_capacity = bin_capacity - for idx, size in enumerate(lengths): + # Process each sequence in order + for idx, size in enumerate(sequence_lengths): if size <= remaining_capacity: # Example fits in current bin current_bin.append(idx) remaining_capacity -= size - total_used += size + total_tokens_used += size else: # Example doesn't fit, start a new bin if current_bin: # Add non-empty bin to all_bins all_bins.append(current_bin) current_bin = [idx] - remaining_capacity = c - size - total_used += size + remaining_capacity = bin_capacity - size + total_tokens_used += size # Add the last bin if not empty if 
current_bin: all_bins.append(current_bin) - # Assign bins to ranks - each rank gets every n-th bin - for bin_idx in range(rank, len(all_bins), n): - result.append(all_bins[bin_idx]) + # Assign bins to ranks - each rank gets every num_ranks-th bin + for bin_idx in range(rank, len(all_bins), num_ranks): + rank_batches.append(all_bins[bin_idx]) - return result, total_used, len(all_bins) * c + return rank_batches, total_tokens_used, len(all_bins) * bin_capacity class MultipackBatchSampler(BatchSampler): - """Batch sampler class for multipack""" + """ + Batch sampler class for efficient packing of variable-length sequences. + + This sampler packs sequences into fixed-capacity bins (batches) to maximize + GPU memory utilization and training throughput by reducing padding. + + It supports both length-optimized packing (using FFD algorithm) and + sequential packing (preserving original sequence order). + """ def __init__( self, sampler: Union[Sampler[int], Iterable[int]], - batch_size: int, - batch_max_len: int, - lengths: np.ndarray, - packing_efficiency_estimate: float = 1.0, - drop_last: bool = False, - num_count_samples: int = 16, - sequential: bool = False, + batch_size: int, # Number of bins per batch + batch_max_len: int, # Maximum sequence length (bin capacity) + lengths: np.ndarray, # Sequence lengths + packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate + drop_last: bool = False, # Whether to drop incomplete batches + num_count_samples: int = 16, # Number of samples to estimate batch count + sequential: bool = False, # Whether to use sequential packing instead of FFD **kwargs, ): super().__init__(sampler, batch_size, drop_last) @@ -180,13 +226,13 @@ class MultipackBatchSampler(BatchSampler): self.epoch = 0 - # statistics - self.eff_total_used = 0 - self.eff_total_slots = 0 + # Efficiency statistics tracking + self.eff_total_used = 0 # Total tokens used + self.eff_total_slots = 0 # Total token slots available - # The number of times to 
calculate the batches to determine the minimum packed dataset length for the local rank
+        # The number of times to calculate batches to determine minimum packed dataset length
         self.num_count_samples = num_count_samples
 
-        # the minimum packed dataset length across all ranks determined by a gather/broadcast
+        # Minimum packed dataset length across all ranks (determined by gather/broadcast)
         self.len_across_ranks = None
 
         if self.sequential and not isinstance(sampler, SequentialSampler):
@@ -195,39 +241,54 @@ class MultipackBatchSampler(BatchSampler):
             )
 
     def set_epoch(self, epoch: int):
+        """Set the epoch number, used for reproducible shuffling across epochs"""
         self.epoch = epoch
 
     def generate_batches(self, set_stats=False):
+        """
+        Generate packed batches for training
+
+        Args:
+            set_stats: Whether to update efficiency statistics
+
+        Returns:
+            List of batches, where each batch contains multiple bins,
+            and each bin contains multiple sequence indices
+        """
+        # Get indices from the sampler
         indices = [idx for idx in self.sampler]
 
+        # Get lengths of the selected sequences
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
 
+        # Pack sequences into bins using either sequential or FFD allocation
         if self.sequential:
-            batches, total_used, total_slots = allocate_sequentially(
-                lengths=lengths,
+            bins, total_used, total_slots = allocate_sequentially(
+                sequence_lengths=lengths,
                 rank=0,
-                c=self.batch_max_len,
-                n=1,
+                bin_capacity=self.batch_max_len,
+                num_ranks=1,
             )
         else:
-            batches, total_used, total_slots = allocate(
-                lengths=lengths,
+            bins, total_used, total_slots = allocate(
+                sequence_lengths=lengths,
                 lengths_cumsum=lengths_cumsum,
                 rank=0,
-                c=self.batch_max_len,
-                n=1,
+                bin_capacity=self.batch_max_len,
+                num_ranks=1,
            )
 
+        # Group bins into batches (each batch contains batch_size bins)
         batches = [
             [
-                [indices[b_idx] for b_idx in batch]
-                for batch in batches[i : i + self.batch_size]
+                [indices[b_idx] for b_idx in bin_indices]
+                for bin_indices in bins[i : i + self.batch_size]
             ]
-            for i in range(0, len(batches), 
self.batch_size) + for i in range(0, len(bins), self.batch_size) ] - # statistics + # Update statistics if requested if set_stats: self.eff_total_used += total_used self.eff_total_slots += total_slots @@ -235,45 +296,79 @@ class MultipackBatchSampler(BatchSampler): return batches def __iter__(self): + """ + Return an iterator over batches + + The batches are truncated to match the minimum number of batches across all ranks + to ensure distributed training balance + """ batches = self.generate_batches(set_stats=True) if self.len_across_ranks: - # make sure the batches we iterate over is truncated to the same min length across all ranks + # Truncate batches to ensure all ranks have the same number of batches batches = batches[: self.len_across_ranks] return iter(batches) def num_batches(self): + """Calculate the number of batches for this rank""" batches = self.generate_batches(set_stats=True) return len(batches) def efficiency(self): + """ + Calculate the packing efficiency (ratio of tokens used to total token slots) + Higher is better - 1.0 would mean perfect packing with no wasted space + """ return self.eff_total_used / self.eff_total_slots def gather_efficiency(self): + """ + Gather and synchronize packing efficiency estimates across all distributed ranks + Returns a conservative efficiency estimate based on the measurements + """ + def calc_sample_packing_eff_est(estimates: List[float]): LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") + # Use 99.7% of max observed efficiency as a safe estimate return math.floor(0.997 * max(estimates)) + # Gather efficiency from all ranks and apply the calculation function sample_packing_actual_eff_all = reduce_and_broadcast( lambda: self.efficiency(), # pylint: disable=unnecessary-lambda calc_sample_packing_eff_est, ) + + # Quantize to 0.5% intervals for stability sample_packing_eff_est = ( math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0 ) return sample_packing_eff_est def gather_len_batches(self, 
num): + """ + Gather and synchronize batch counts across all distributed ranks + Returns the minimum number of batches available on any rank + """ + def calc_min_len(estimates: list[(int, float)]): LOG.info(f"gather_len_batches: {repr(estimates)}") return math.floor(min(estimates)) + # Find minimum batch count across ranks to ensure balance min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) return min_len_batches def __len__(self): + """ + Return the total number of batches that will be yielded by this sampler + + This is calculated as the minimum number of batches available on any rank + to ensure balanced distributed training + """ if not self.len_across_ranks: + # Sample multiple times to get stable estimate len_batches = min( [self.num_batches() for _ in range(self.num_count_samples)] ) + # Gather minimum across all ranks self.len_across_ranks = self.gather_len_batches(len_batches) return self.len_across_ranks