diff --git a/docs/config.qmd b/docs/config.qmd
index 570a173f9..bc44964dc 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -186,6 +186,11 @@ eval_sample_packing:
 # The trainer will provide recommended values for these values.
 sample_packing_eff_est:
 total_num_tokens:
+# Increasing the following values helps with packing, but usually only slightly (<1%).
+# The number of samples packed at a time.
+sample_packing_group_size: 100000
+# The number of samples that can be packed into one sequence. Increase if using a large sequence_len with many short samples.
+sample_packing_bin_size: 200
 
 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory
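For reference, a minimal sketch of how the two new options sit in an axolotl YAML config; the surrounding values are illustrative, not recommendations:

```yaml
sample_packing: true
pad_to_sequence_len: true
sequence_len: 2048
micro_batch_size: 2

# New in this PR; the values shown are the defaults.
# How many samples each packing group works on at a time.
sample_packing_group_size: 100000
# Upper bound on how many samples may share one packed sequence.
sample_packing_bin_size: 200
```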
training_arguments_kwargs["sample_packing"] = ( - self.cfg.sample_packing if self.cfg.sample_packing else False - ) - training_arguments_kwargs["multipack_real_batches"] = ( - self.cfg.flash_attention is not True - ) - training_arguments_kwargs["eval_sample_packing"] = ( - self.cfg.sample_packing - if self.cfg.eval_sample_packing is not False - else False - ) + + training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs[ - "sample_packing_seq_len_multiplier" - ] = self.cfg.micro_batch_size + "multipack_real_batches" + ] = not self.cfg.flash_attention + training_arguments_kwargs["eval_sample_packing"] = bool( + self.cfg.eval_sample_packing + ) + if self.cfg.sample_packing_bin_size is not None: + training_arguments_kwargs[ + "sample_packing_bin_size" + ] = self.cfg.sample_packing_bin_size + if self.cfg.sample_packing_group_size is not None: + training_arguments_kwargs[ + "sample_packing_group_size" + ] = self.cfg.sample_packing_group_size + if self.cfg.sample_packing_eff_est: + training_arguments_kwargs[ + "sample_packing_efficiency" + ] = self.cfg.sample_packing_eff_est + if self.cfg.relora_steps: training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs[ diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index b6eafa7a3..a14b66fa3 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -551,6 +551,8 @@ class AxolotlInputConfig( default=512, metadata={"help": "maximum prompt length for RL training"} ) sample_packing: Optional[bool] = None + sample_packing_group_size: Optional[int] = 100_000 + sample_packing_bin_size: Optional[int] = 200 eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None curriculum_sampling: Optional[bool] = None diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index 544ed1316..e056c7f50 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -150,6 +150,8 @@ def wrap_pretraining_dataset( max_seq_length=max_tokens, batch_size=batch_size, multipack_attn=cfg.pretrain_multipack_attn, + group_size=cfg.sample_packing_group_size, + bin_size=cfg.sample_packing_bin_size, ) # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 @@ -189,6 +191,8 @@ def encode_packed_pretraining( max_seq_length: int = 2048, batch_size: int = 4, multipack_attn: Optional[bool] = False, + group_size: int = 100000, + bin_size: int = 200, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples @@ -202,11 +206,13 @@ def encode_packed_pretraining( ) sampler = MultipackBatchSampler( - RandomSampler(train_dataset), - batch_size=1, - drop_last=True, - batch_max_len=batch_size * max_seq_length, + sampler=RandomSampler(train_dataset), lengths=get_dataset_lengths(train_dataset), + batch_size=1, + batch_max_len=batch_size * max_seq_length, + group_size=group_size, + bin_size=bin_size, + drop_last=True, ) chunked_data = defaultdict(list) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index cf47d9639..07fd05682 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -1,105 +1,64 @@ -# pylint: skip-file """ Multipack Batch Sampler """ import logging -import math -import os -from typing import Any, 
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index cf47d9639..07fd05682 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -1,105 +1,64 @@
-# pylint: skip-file
 """
 Multipack Batch Sampler
 """
 import logging
-import math
-import os
-from typing import Any, Iterable, List, Union
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
 
 import numba
 import numpy as np
-from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data import BatchSampler
 
 LOG = logging.getLogger("axolotl.utils.samplers.multipack")
 
 
+# First-fit-decreasing bin packing.
 @numba.njit
-def ffd_check(a: np.ndarray, c: int, n: int):
-    # First-fit-decreasing bin packing
-    # Check if a[] could fit in n bins with capacity c
-    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
+    idxs = np.argsort(items)[::-1]
+    sorted_items = items[idxs]
+    num_bins = len(items)
+    bins = np.full(num_bins, bin_capacity, dtype=np.int32)
+    bin_counts = np.zeros(num_bins, dtype=np.int32)
+    group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)
 
-    a = np.sort(a)[::-1]
-    bins = np.full((n,), c, dtype=a.dtype)
-    for size in a:
-        not_found = True
-        for idx in range(n):
-            if bins[idx] >= size:
-                bins[idx] -= size
-                not_found = False
+    for idx, item in enumerate(sorted_items):
+        global_idx = idxs[idx] + group_offset
+
+        placed = False
+        for i in range(num_bins):
+            if bins[i] >= item and bin_counts[i] < max_items_per_bin:
+                bins[i] -= item
+                group_packing[i, bin_counts[i]] = global_idx
+                bin_counts[i] += 1
+                placed = True
                 break
-        if not_found:
-            return False
+        if not placed:
+            raise ValueError(
+                f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
+            )
 
-    return True
+    return group_packing
 
 
-@numba.njit
-def ffd_with_result(a: np.ndarray, c: int, start_index: int):
-    # First-fit-decreasing bin packing (with result return)
+def pack(items, bin_capacity, group_size, max_items_per_bin):
+    num_items = len(items)
+    num_processes = max(1, min(num_items // group_size, cpu_count()))
+    tasks = [
+        (items[i : i + group_size], i, bin_capacity, max_items_per_bin)
+        for i in range(0, num_items, group_size)
+    ]
 
-    indices = np.argsort(a)[::-1]
-    a = a[indices]
+    packed_bins = []
+    with ProcessPoolExecutor(max_workers=num_processes) as executor:
+        for group_packing in executor.map(pack_group, *zip(*tasks)):
+            for bin_pack in group_packing:
+                filtered_pack = bin_pack[bin_pack != -1]
+                if filtered_pack.size > 0:
+                    packed_bins.append(filtered_pack.tolist())
 
-    bins: List[Any] = []
-    bins_result: List[Any] = []
-    for a_id, size in enumerate(a):
-        add_new = True
-        for idx in range(len(bins)):
-            if bins[idx] >= size:
-                bins[idx] -= size
-                bins_result[idx].append(indices[a_id] + start_index)
-                add_new = False
-                break
-
-        if add_new:
-            bins.append(c - size)
-            bins_result.append([indices[a_id] + start_index])
-
-    return bins_result
-
-
-@numba.njit
-def allocate(
-    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
-):
-    # Dynamic batch allocator, similar to Multifit
-    # https://en.wikipedia.org/wiki/Multifit_algorithm
-    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
-
-    s = 0
-    start_index = 0
-    result = []
-
-    while True:
-        # binary search [l, r)
-        left = 1
-        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
-
-        while right - left > 1:
-            mid = (left + right) // 2
-            if ffd_check(lengths[start_index : start_index + mid], c, n):
-                left = mid
-            else:
-                right = mid
-
-        # use length l
-        batch = ffd_with_result(
-            lengths[start_index : start_index + left], c, start_index
-        )
-        assert len(batch) <= n
-        if len(batch) < n:
-            break
-
-        start_index += left
-        s = lengths_cumsum[start_index - 1]
-
-        # add local rank
-        result.append(batch[rank])
-
-    return result, s, len(result) * c * n
+    return packed_bins
 
 
 class MultipackBatchSampler(BatchSampler):
@@ -109,94 +68,63 @@ class MultipackBatchSampler(BatchSampler):
 
     def __init__(
         self,
-        sampler: Union[Sampler[int], Iterable[int]],
-        batch_size: int,
-        drop_last: bool,
-        batch_max_len: int,
-        lengths: np.ndarray,
-        packing_efficiency_estimate: float = 1.0,
+        sampler,
+        lengths,
+        batch_max_len,
+        batch_size,
+        group_size=100_000,
+        bin_size=200,
+        drop_last=False,
     ):
-        super().__init__(sampler, batch_size, drop_last)
-        self.batch_size = batch_size
+        self.sampler = sampler
+        self.lengths = np.array(lengths, dtype=np.int32)
         self.batch_max_len = batch_max_len
-        self.lengths: np.ndarray = lengths
-        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.batch_size = batch_size
+        self.group_size = group_size
+        self.bin_size = bin_size
+        self.drop_last = drop_last
 
-        assert isinstance(self.lengths, np.ndarray)
-
-        self.epoch = 0
-
-        # statistics
-        self.eff_total_used = 0
-        self.eff_total_slots = 0
-
-    def set_epoch(self, epoch: int):
-        self.epoch = epoch
-
-    def generate_batches(self, set_stats=False):
-        indices = [idx for idx in self.sampler]
-
-        lengths = self.lengths[indices]
-        lengths_cumsum = np.cumsum(lengths)
-
-        batches, total_used, total_slots = allocate(
-            lengths=lengths,
-            lengths_cumsum=lengths_cumsum,
-            rank=0,
-            c=self.batch_max_len,
-            n=1,
-        )
-
-        batches = [
-            [
-                [indices[b_idx] for b_idx in batch]
-                for batch in batches[i : i + self.batch_size]
-            ]
-            for i in range(0, len(batches), self.batch_size)
-        ]
-
-        # statistics
-        if set_stats:
-            self.eff_total_used += total_used
-            self.eff_total_slots += total_slots
-
-        return batches
-
-    def __iter__(self):
-        batches = self.generate_batches(set_stats=True)
-        return iter(batches)
-
-    def num_batches(self):
-        batches = self.generate_batches(set_stats=True)
-        return len(batches)
+        self._efficiency = None
+        self._batches = None
 
     def efficiency(self):
-        return self.eff_total_used / self.eff_total_slots
+        if self._efficiency is None:
+            self._batches = self._pack_batches()
+        return self._efficiency
+
+    def _pack_batches(self):
+        # NOTE: packs are built in dataset index order; the sampler only supplies its length.
+        sample_idxs = np.arange(len(self.sampler))
+        lengths = self.lengths[sample_idxs]
+
+        pack_idxs = pack(
+            lengths,
+            self.batch_max_len,
+            self.group_size,
+            self.bin_size,
+        )
+
+        used_tokens = self.lengths.sum()
+        available_tokens = len(pack_idxs) * self.batch_max_len
+        self._efficiency = used_tokens / available_tokens
+
+        # Wrap packs into batches.
+        batch_idxs = [
+            pack_idxs[i : i + self.batch_size]
+            for i in range(0, len(pack_idxs), self.batch_size)
+        ]
+
+        # Drop last batch if needed.
+        if self.drop_last and len(batch_idxs[-1]) < self.batch_size:
+            batch_idxs = batch_idxs[:-1]
+
+        return batch_idxs
+
+    def __iter__(self):
+        self._batches = self._pack_batches()
+        return iter(self._batches)
 
     def __len__(self):
-        self.num_batches()
-        return self._len_est()
-
-    def _len_est(self):
-        world_size = int(os.getenv("WORLD_SIZE", "1"))
-        lengths_sum = np.sum(self.lengths)
-        lengths_sum_per_device = lengths_sum // world_size
-        LOG.info(
-            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
-            f"total_num_tokens per device: {lengths_sum_per_device}"
-        )
-
-        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
-        return max(
-            0,
-            (
-                world_size
-                * math.floor(
-                    0.99
-                    * lengths_sum_per_device
-                    / self.packing_efficiency_estimate
-                    // (self.batch_max_len * self.batch_size)
-                )
-                - 1
-            ),
-        )
+        if self._batches is None:
+            self._batches = self._pack_batches()
+        return len(self._batches)
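To sanity-check the numba routine above without compiling it, here is a pure-Python sketch of the same first-fit-decreasing idea. It opens a new bin when nothing fits instead of raising, and it is illustrative only, not part of the patch:

```python
def ffd_pack(items, bin_capacity, max_items_per_bin):
    """First-fit decreasing: place each item, largest first, into the first
    bin with enough room; open a new bin when no existing bin fits."""
    order = sorted(range(len(items)), key=lambda i: items[i], reverse=True)
    remaining, contents = [], []  # capacity left / item indices, per bin
    for idx in order:
        size = items[idx]
        for pos, room in enumerate(remaining):
            if room >= size and len(contents[pos]) < max_items_per_bin:
                remaining[pos] -= size
                contents[pos].append(idx)
                break
        else:
            remaining.append(bin_capacity - size)
            contents.append([idx])
    return contents


print(ffd_pack([512, 38, 1024, 77, 301], bin_capacity=1024, max_items_per_bin=200))
# [[2], [0, 4, 3, 1]] -- sample 2 fills a bin alone; the rest share one bin
```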
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 83977ef06..6760dc488 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -341,27 +341,26 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             )
         else:
             if cfg.flash_attention:
-                batch_size = 1
+                sampler_batch_size = 1
                 batch_max_len = cfg.micro_batch_size * cfg.sequence_len
             else:
-                batch_size = cfg.micro_batch_size
+                sampler_batch_size = cfg.micro_batch_size
                 batch_max_len = cfg.sequence_len
             sampler = MultipackBatchSampler(
                 sampler=RandomSampler(train_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
                 lengths=get_dataset_lengths(train_dataset),
+                batch_size=sampler_batch_size,
+                batch_max_len=batch_max_len,
+                group_size=cfg.sample_packing_group_size,
+                bin_size=cfg.sample_packing_bin_size,
+                drop_last=True,
             )
 
             data_loader = DataLoader(
                 train_dataset.remove_columns(["length"]),
                 batch_sampler=sampler,
             )
-            data_loader_len = len(data_loader) // (
-                cfg.world_size * cfg.gradient_accumulation_steps
-            )
-            actual_eff = sampler.efficiency()
+            data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
             LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
@@ -372,7 +371,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 return max(estimates)
 
             sample_packing_actual_eff_all = reduce_and_broadcast(
-                lambda: actual_eff,
+                lambda: sampler.efficiency(),  # pylint: disable=unnecessary-lambda
                 calc_sample_packing_eff_est,
             )
             sample_packing_eff_est = (
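A worked sketch of the arithmetic in the hunk above, with made-up numbers. It assumes axolotl's usual normalization `cfg.batch_size = micro_batch_size * gradient_accumulation_steps`, which is not part of this diff:

```python
# Made-up numbers to illustrate efficiency() and the new data_loader_len math.
total_tokens = 10_000_000  # sum of per-sample lengths across the dataset
sequence_len = 2048
micro_batch_size = 2
gradient_accumulation_steps = 8

batch_max_len = micro_batch_size * sequence_len  # 4096, flash-attention path
num_packs = 2_600  # however many packs the sampler produced

# MultipackBatchSampler.efficiency(): used tokens over available token slots.
efficiency = total_tokens / (num_packs * batch_max_len)  # ~0.94

# Sampler batch_size is 1 here, so len(data_loader) == num_packs.
batch_size = micro_batch_size * gradient_accumulation_steps  # cfg.batch_size == 16
data_loader_len = num_packs * micro_batch_size // batch_size  # 325 optimizer steps
```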
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index 50f39d60f..ceff11df9 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -62,12 +62,14 @@ class TestBatchedSamplerPacking:
             dataset,
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
+        lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(
             sampler=RandomSampler(train_dataset),
+            lengths=lengths,
             batch_size=batch_size,
-            drop_last=True,
             batch_max_len=max_seq_length,
-            lengths=get_dataset_lengths(train_dataset),
+            group_size=100000,
+            bin_size=200,
         )
 
         loader = DataLoader(
@@ -81,19 +83,15 @@ class TestBatchedSamplerPacking:
             ),
             num_workers=num_workers,
         )
-        inputs = next(iter(loader))
 
-        assert inputs["input_ids"].shape == (batch_size, max_seq_length)
-        assert inputs["labels"].shape == (batch_size, max_seq_length)
-        assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
+        batch_idxs = []
+        for batch in batch_sampler:
+            for pack in batch:
+                batch_idxs.extend(pack)
 
-        assert inputs["input_ids"].tolist()[0][0] == 2
-        assert inputs["labels"].tolist()[0][0] == -100
-        assert inputs["attention_mask"].tolist()[0][0] == 0
-        assert inputs["attention_mask"].tolist()[0][-1] > 1
+        for batch in loader:
+            assert len(batch["input_ids"]) <= batch_size * max_seq_length
+            assert batch["input_ids"].shape[1] == max_seq_length
 
-        if batch_size >= 2:
-            assert inputs["input_ids"].tolist()[1][0] == 2
-            assert inputs["labels"].tolist()[1][0] == -100
-            assert inputs["attention_mask"].tolist()[1][0] == 0
-            assert inputs["attention_mask"].tolist()[1][-1] > 1
+        original_idxs = set(range(len(train_dataset)))
+        assert original_idxs == set(batch_idxs)
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index 528f9c807..fb623a43d 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -42,6 +42,8 @@ class TestPretrainingPacking(unittest.TestCase):
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "micro_batch_size": 2,
+                "sample_packing_group_size": 100000,
+                "sample_packing_bin_size": 200,
             }
         )
 