improve readability of multipack sampler

This commit is contained in:
Wing Lian
2025-05-04 18:02:17 -04:00
parent 48b3e14a24
commit 03508c6816

View File

@@ -4,7 +4,7 @@ Multipack Batch Sampler
"""
import logging
import math
from typing import Any, Iterable, List, Union
from typing import Iterable, List, Union
import numba
import numpy as np
@@ -18,21 +18,27 @@ LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
# First-fit-decreasing bin packing algorithm
# Checks if sequences with lengths in sequence_lengths[] could fit in num_bins bins, each with capacity bin_capacity
# Returns True if all sequences can be packed, False otherwise
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
not_found = True
for idx in range(n):
for idx in range(num_bins):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -40,133 +46,173 @@ def ffd_check(a: np.ndarray, c: int, n: int):
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
def ffd_with_result(sequence_lengths: np.ndarray, bin_capacity: int, start_index: int):
# First-fit-decreasing bin packing that returns the actual bin assignments
# Returns a list of bins, where each bin contains indices of sequences assigned to it
indices = np.argsort(a)[::-1]
a = a[indices]
# Get sorting indices and sort sequence lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sequence_lengths = sequence_lengths[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
# Place each sequence in the first bin it fits
for seq_id, size in enumerate(sequence_lengths):
add_new_bin = True
for bin_idx in range(len(bins_remaining_space)):
if bins_remaining_space[bin_idx] >= size:
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(indices[seq_id] + start_index)
add_new_bin = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# If no existing bin could fit this sequence, create a new bin
if add_new_bin:
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([indices[seq_id] + start_index])
return bins_result
return bins_assigned_sequences
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
sequence_lengths: np.ndarray,
lengths_cumsum: np.ndarray,
rank: int,
bin_capacity: int,
num_ranks: int,
):
# Dynamic batch allocator, similar to Multifit
# Dynamic batch allocator, similar to Multifit algorithm
# Efficiently packs sequences into fixed-capacity bins for distributed training
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
total_processed_tokens = 0
start_index = 0
result = []
rank_batches = [] # Batches assigned to the current rank
while True:
# binary search [l, r)
# Binary search to find maximum number of sequences that can be packed into num_ranks bins
# [left, right) defines the search range
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
right = 1 + np.searchsorted(
lengths_cumsum[start_index:],
total_processed_tokens + bin_capacity * num_ranks,
"right",
)
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
if ffd_check(
sequence_lengths[start_index : start_index + mid],
bin_capacity,
num_ranks,
):
left = mid
else:
right = mid
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
# Pack the identified sequences into bins
all_rank_batches = ffd_with_result(
sequence_lengths[start_index : start_index + left],
bin_capacity,
start_index,
)
assert len(batch) <= n
if len(batch) < n:
assert len(all_rank_batches) <= num_ranks
# If we couldn't fill all ranks, we're done
if len(all_rank_batches) < num_ranks:
break
# Update indices and processed token count
start_index += left
s = lengths_cumsum[start_index - 1]
total_processed_tokens = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# Add the batch for the current rank
rank_batches.append(all_rank_batches[rank])
return result, s, len(result) * c * n
# Return batches for this rank, total tokens used, and total token slots available
return (
rank_batches,
total_processed_tokens,
len(rank_batches) * bin_capacity * num_ranks,
)
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
"""
Sequential allocator that preserves example order
Sequential allocator that preserves example order (no sorting by length)
Parameters:
- lengths: The lengths of all examples
Arguments:
- sequence_lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
- bin_capacity: The capacity of each bin (maximum sequence length)
- num_ranks: Number of ranks (processes/GPUs)
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
- rank_batches: List of batches for the current rank
- total_tokens_used: Number of actual example tokens
- total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
rank_batches = []
total_tokens_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
current_bin = [0 for i in range(0)] # numba hint for empty list of integers
remaining_capacity = bin_capacity
for idx, size in enumerate(lengths):
# Process each sequence in order
for idx, size in enumerate(sequence_lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
total_tokens_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
remaining_capacity = bin_capacity - size
total_tokens_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
# Assign bins to ranks - each rank gets every num_ranks-th bin
for bin_idx in range(rank, len(all_bins), num_ranks):
rank_batches.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
return rank_batches, total_tokens_used, len(all_bins) * bin_capacity
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
"""
Batch sampler class for efficient packing of variable-length sequences.
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both length-optimized packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing instead of FFD
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -180,13 +226,13 @@ class MultipackBatchSampler(BatchSampler):
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# Efficiency statistics tracking
self.eff_total_used = 0 # Total tokens used
self.eff_total_slots = 0 # Total token slots available
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
@@ -195,39 +241,54 @@ class MultipackBatchSampler(BatchSampler):
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
def generate_batches(self, set_stats=False):
"""
Generate packed batches for training
Args:
set_stats: Whether to update efficiency statistics
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
# Get indices from the sampler
indices = [idx for idx in self.sampler]
# Get lengths of the selected sequences
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
# Pack sequences into bins using either sequential or FFD allocation
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
bins, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
else:
batches, total_used, total_slots = allocate(
bins, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
# Group bins into batches (each batch contains batch_size bins)
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
[indices[b_idx] for b_idx in bin_indices]
for bin_indices in bins[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
for i in range(0, len(bins), self.batch_size)
]
# statistics
# Update statistics if requested
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
@@ -235,45 +296,79 @@ class MultipackBatchSampler(BatchSampler):
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
# Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
"""Calculate the number of batches for this rank"""
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
return math.floor(0.997 * max(estimates))
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if not self.len_across_ranks:
# Sample multiple times to get stable estimate
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks