improve readability of multipack sampler

This commit is contained in:
Wing Lian
2025-05-04 18:02:17 -04:00
parent 48b3e14a24
commit 03508c6816

View File

@@ -4,7 +4,7 @@ Multipack Batch Sampler
""" """
import logging import logging
import math import math
from typing import Any, Iterable, List, Union from typing import Iterable, List, Union
import numba import numba
import numpy as np import numpy as np
@@ -18,21 +18,27 @@ LOG.setLevel(logging.INFO)
@numba.njit @numba.njit
def ffd_check(a: np.ndarray, c: int, n: int): def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
# First-fit-decreasing bin packing # First-fit-decreasing bin packing algorithm
# Check if a[] could fit in n bins with capacity c # Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity
# Returns True if all sequences can be packed, False otherwise
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1] # Sort sequence lengths in descending order for optimal packing
bins = np.full((n,), c, dtype=a.dtype) sequence_lengths = np.sort(sequence_lengths)[::-1]
for size in a: # Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
not_found = True not_found = True
for idx in range(n): for idx in range(num_bins):
if bins[idx] >= size: if bins[idx] >= size:
bins[idx] -= size bins[idx] -= size
not_found = False not_found = False
break break
# If no bin could fit this sequence, packing failed
if not_found: if not_found:
return False return False
@@ -40,133 +46,173 @@ def ffd_check(a: np.ndarray, c: int, n: int):
@numba.njit
def ffd_with_result(
    sequence_lengths: np.ndarray, bin_capacity: int, start_index: int
):
    """
    First-fit-decreasing bin packing that returns the actual bin assignments.

    Sequences are visited from longest to shortest. Each one goes into the
    first already-open bin that still has room for it; when no open bin fits,
    a new bin is opened for it.

    Arguments:
    - sequence_lengths: Lengths of the sequences to pack
    - bin_capacity: Capacity of each bin
    - start_index: Offset added to every returned index, so the result can
      refer to positions in a larger array that `sequence_lengths` was
      sliced from

    Returns:
    - List of bins; each bin is a list of `position + start_index` values for
      the sequences placed in it
    """
    # Visit positions from the longest sequence to the shortest
    indices = np.argsort(sequence_lengths)[::-1]
    sequence_lengths = sequence_lengths[indices]

    bins_remaining_space: list = []  # Remaining capacity of each open bin
    bins_assigned_sequences: list = []  # Sequence indices assigned to each bin

    for seq_id, size in enumerate(sequence_lengths):
        add_new_bin = True
        for bin_idx in range(len(bins_remaining_space)):
            if bins_remaining_space[bin_idx] >= size:
                bins_remaining_space[bin_idx] -= size
                bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index)
                add_new_bin = False
                break

        # No open bin could take this sequence, so open a new one. There is no
        # check against bin_capacity here: a sequence longer than the capacity
        # still gets its own bin (with negative remaining space).
        if add_new_bin:
            bins_remaining_space.append(bin_capacity - size)
            bins_assigned_sequences.append([indices[seq_id] + start_index])

    return bins_assigned_sequences
@numba.njit
def allocate(
    sequence_lengths: np.ndarray,
    lengths_cumsum: np.ndarray,
    rank: int,
    bin_capacity: int,
    num_ranks: int,
):
    """
    Dynamic batch allocator, similar to the Multifit algorithm.

    https://en.wikipedia.org/wiki/Multifit_algorithm
    ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)

    Sequences are consumed in order, one group at a time. For each group a
    binary search (using `ffd_check`) finds how many of the next sequences fit
    into `num_ranks` bins of `bin_capacity`; that group is then packed with
    `ffd_with_result` and the bin belonging to `rank` is kept.

    Arguments:
    - sequence_lengths: The lengths of all examples
    - lengths_cumsum: Cumulative sum of `sequence_lengths`
    - rank: Index of the bin to keep from every packed group
    - bin_capacity: The capacity of each bin
    - num_ranks: Number of bins packed per group

    Returns:
    - rank_batches: List of bins kept for `rank`
    - total_processed_tokens: Cumulative length of all consumed sequences
    - total_token_slots: `len(rank_batches) * bin_capacity * num_ranks`
    """
    total_processed_tokens = 0
    start_index = 0
    rank_batches = []  # Bins kept for the requested rank

    while True:
        # Binary search over [left, right) for the number of sequences to take.
        # The upper bound comes from the cumulative sum: no more sequences than
        # fit into the combined capacity of num_ranks bins.
        left = 1
        right = 1 + np.searchsorted(
            lengths_cumsum[start_index:],
            total_processed_tokens + bin_capacity * num_ranks,
            "right",
        )

        while right - left > 1:
            mid = (left + right) // 2
            if ffd_check(
                sequence_lengths[start_index : start_index + mid],
                bin_capacity,
                num_ranks,
            ):
                left = mid
            else:
                right = mid

        # Pack the `left` sequences selected by the search
        all_rank_batches = ffd_with_result(
            sequence_lengths[start_index : start_index + left],
            bin_capacity,
            start_index,
        )
        assert len(all_rank_batches) <= num_ranks

        # Fewer bins than ranks: stop. The sequences of this last, partial
        # group are not added to the result.
        if len(all_rank_batches) < num_ranks:
            break

        # Advance past the consumed sequences
        start_index += left
        total_processed_tokens = lengths_cumsum[start_index - 1]

        rank_batches.append(all_rank_batches[rank])

    return (
        rank_batches,
        total_processed_tokens,
        len(rank_batches) * bin_capacity * num_ranks,
    )
@numba.njit
def allocate_sequentially(
    sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
    """
    Sequential allocator that preserves example order (no sorting by length)

    Arguments:
    - sequence_lengths: The lengths of all examples
    - rank: The current rank (for distributed training)
    - bin_capacity: The capacity of each bin (maximum sequence length)
    - num_ranks: Number of ranks (processes/GPUs)

    Returns:
    - rank_batches: List of batches for the current rank
    - total_tokens_used: Number of actual example tokens
    - total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
    """
    rank_batches = []
    total_tokens_used = 0

    # First, do sequential packing into bins
    all_bins = []
    current_bin = [0 for i in range(0)]  # numba hint for empty list of integers
    remaining_capacity = bin_capacity

    for idx, size in enumerate(sequence_lengths):
        if size <= remaining_capacity:
            # Example fits in current bin
            current_bin.append(idx)
            remaining_capacity -= size
        else:
            # Example doesn't fit, start a new bin. An example longer than
            # bin_capacity is still placed, alone at the start of its new bin.
            if current_bin:  # Add non-empty bin to all_bins
                all_bins.append(current_bin)
            current_bin = [idx]
            remaining_capacity = bin_capacity - size
        # Every example is placed in exactly one bin, so it always counts
        total_tokens_used += size

    # Add the last bin if not empty
    if current_bin:
        all_bins.append(current_bin)

    # Assign bins to ranks - each rank gets every num_ranks-th bin
    for bin_idx in range(rank, len(all_bins), num_ranks):
        rank_batches.append(all_bins[bin_idx])

    # Note: used tokens and slots are totals over all bins, not only this rank's
    return rank_batches, total_tokens_used, len(all_bins) * bin_capacity
class MultipackBatchSampler(BatchSampler): class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack""" """
Batch sampler class for efficient packing of variable-length sequences.
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both length-optimized packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
def __init__( def __init__(
self, self,
sampler: Union[Sampler[int], Iterable[int]], sampler: Union[Sampler[int], Iterable[int]],
batch_size: int, batch_size: int, # Number of bins per batch
batch_max_len: int, batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, sequential: bool = False, # Whether to use sequential packing instead of FFD
**kwargs, **kwargs,
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
@@ -180,13 +226,13 @@ class MultipackBatchSampler(BatchSampler):
self.epoch = 0 self.epoch = 0
# statistics # Efficiency statistics tracking
self.eff_total_used = 0 self.eff_total_used = 0 # Total tokens used
self.eff_total_slots = 0 self.eff_total_slots = 0 # Total token slots available
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank # The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast # Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler): if self.sequential and not isinstance(sampler, SequentialSampler):
@@ -195,39 +241,54 @@ class MultipackBatchSampler(BatchSampler):
) )
def set_epoch(self, epoch: int):
    """
    Store the epoch number on the sampler.

    This method only records the value in `self.epoch`; it does not reshuffle
    or regenerate anything by itself.
    """
    self.epoch = epoch
def generate_batches(self, set_stats=False):
    """
    Generate packed batches for training

    Args:
        set_stats: Whether to add this call's token counts to the running
            efficiency statistics (`eff_total_used` / `eff_total_slots`)

    Returns:
        List of batches, where each batch contains up to `batch_size` bins,
        and each bin contains multiple sequence indices (as yielded by the
        underlying sampler)
    """
    # Materialize the sampler order, then look up the matching lengths
    indices = list(self.sampler)
    lengths = self.lengths[indices]

    # Pack sequences into bins using either sequential or FFD allocation.
    # Arguments are passed positionally so the calls do not depend on the
    # parameter names of the allocator functions.
    if self.sequential:
        bins, total_used, total_slots = allocate_sequentially(
            lengths,
            0,  # rank
            self.batch_max_len,  # bin capacity
            1,  # number of ranks
        )
    else:
        # The cumulative sum is only needed by the FFD allocator
        lengths_cumsum = np.cumsum(lengths)
        bins, total_used, total_slots = allocate(
            lengths,
            lengths_cumsum,
            0,  # rank
            self.batch_max_len,  # bin capacity
            1,  # number of ranks
        )

    # Bins hold positions into `indices`; map them back to sampler indices and
    # group them into batches of `batch_size` bins (the last may be shorter)
    batches = [
        [
            [indices[b_idx] for b_idx in bin_indices]
            for bin_indices in bins[i : i + self.batch_size]
        ]
        for i in range(0, len(bins), self.batch_size)
    ]

    # Update statistics if requested
    if set_stats:
        self.eff_total_used += total_used
        self.eff_total_slots += total_slots

    return batches
def __iter__(self): def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True) batches = self.generate_batches(set_stats=True)
if self.len_across_ranks: if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks # Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks] batches = batches[: self.len_across_ranks]
return iter(batches) return iter(batches)
def num_batches(self):
    """
    Generate the batches once and return how many there are.

    Note that this calls `generate_batches(set_stats=True)`, so every call
    also adds to the running efficiency statistics.
    """
    batches = self.generate_batches(set_stats=True)
    return len(batches)
def efficiency(self):
    """
    Calculate the packing efficiency (ratio of tokens used to total token slots)

    Higher is better - 1.0 would mean perfect packing with no wasted space.

    Raises:
        ZeroDivisionError: if no slots have been recorded yet, i.e.
            `eff_total_slots` is still 0.
    """
    return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
    """
    Combine the packing efficiency of this sampler with the other ranks'
    values and return a quantized estimate.

    The per-rank value comes from `self.efficiency()`. The combination itself
    is delegated to `reduce_and_broadcast`, which is defined outside this
    class (presumably it collects the value from every rank and hands the
    list to the reducer below - confirm against its definition).

    Returns:
        The combined estimate rounded up to the next multiple of 0.005.
    """

    def calc_sample_packing_eff_est(estimates: List[float]):
        LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
        # Use 99.7% of the max observed efficiency as the estimate. The value
        # is a ratio (normally <= 1.0), so it must stay fractional: flooring
        # it would turn every realistic efficiency into 0.
        return 0.997 * max(estimates)

    sample_packing_actual_eff_all = reduce_and_broadcast(
        lambda: self.efficiency(),  # pylint: disable=unnecessary-lambda
        calc_sample_packing_eff_est,
    )

    # Quantize (rounding up) to 0.5% steps
    sample_packing_eff_est = (
        math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
    )
    return sample_packing_eff_est
def gather_len_batches(self, num):
    """
    Combine this rank's batch count with the other ranks' counts.

    The combination is delegated to `reduce_and_broadcast`, which is defined
    outside this class (presumably it collects `num` from every rank and
    hands the list to the reducer below - confirm against its definition).

    Args:
        num: Number of batches counted on this rank

    Returns:
        The result of `reduce_and_broadcast`; the reducer used here yields the
        smallest reported count, floored to an integer.
    """

    def calc_min_len(estimates: List[Union[int, float]]):
        LOG.info(f"gather_len_batches: {repr(estimates)}")
        return math.floor(min(estimates))

    # Find minimum batch count across ranks to ensure balance
    min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
    return min_len_batches
def __len__(self): def __len__(self):
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if not self.len_across_ranks: if not self.len_across_ranks:
# Sample multiple times to get stable estimate
len_batches = min( len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)] [self.num_batches() for _ in range(self.num_count_samples)]
) )
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches) self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks return self.len_across_ranks