Sequential sample packing (#2404) [skip ci]

* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
DreamGenX
2025-03-31 21:48:20 +02:00
committed by GitHub
parent 7acf93b59f
commit 4d36ecc724
7 changed files with 174 additions and 11 deletions

View File

@@ -112,6 +112,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
@@ -103,6 +103,55 @@ def allocate(
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
@@ -115,6 +164,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -122,6 +172,7 @@ class MultipackBatchSampler(BatchSampler):
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
assert isinstance(self.lengths, np.ndarray)
@@ -136,6 +187,11 @@ class MultipackBatchSampler(BatchSampler):
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -145,13 +201,23 @@ class MultipackBatchSampler(BatchSampler):
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
if self.sequential:
LOG.debug("using sequential sample packing algorithm")
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
LOG.debug("using non-sequential sample packing algorithm")
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [
[

View File

@@ -192,6 +192,7 @@ class AxolotlInputConfig(
sample_packing: bool | None = None
sample_packing_group_size: int | None = 100_000
sample_packing_bin_size: int | None = 200
sample_packing_sequentially: bool | None = None
eval_sample_packing: bool | None = None
pad_to_sequence_len: bool | None = None
curriculum_sampling: bool | None = None

View File

@@ -13,7 +13,7 @@ import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
else:
sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
sampler=sampler,
lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size,
batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
)