From 468580d18efe5ec55c050e72b3fc3b30ad390641 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Jun 2025 13:22:58 -0400 Subject: [PATCH] limit multipack sampler processes (#2771) [skip ci] * limit to 16 packing processes * make num_processes properly reflect configured dataset_processes --- src/axolotl/core/builders/base.py | 3 +++ src/axolotl/core/builders/rl.py | 4 ---- src/axolotl/core/trainers/base.py | 1 + src/axolotl/core/training_args.py | 4 ++++ src/axolotl/utils/samplers/multipack.py | 7 +++++-- src/axolotl/utils/trainer.py | 1 + 6 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 907de056b..ac49b4e88 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -490,6 +490,9 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs + if self.cfg.dataset_processes: + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + # max_length is not used in CausalTrainer if self.cfg.reward_model or self.cfg.rl: training_args_kwargs["max_length"] = self.cfg.sequence_len diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 14dbfa715..80c5a9eef 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -90,10 +90,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: training_args_kwargs["remove_unused_columns"] = False - # only rlhf - if self.cfg.dataset_processes: - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - if self.cfg.trl and self.cfg.trl.beta is not None: training_args_kwargs["beta"] = self.cfg.trl.beta elif self.cfg.rl_beta is not None: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d6f2c579a..25ffb4cbf 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -114,6 +114,7 @@ class AxolotlTrainer( bin_size=self.args.sample_packing_bin_size, sequential=self.args.sample_packing_sequentially, drop_last=True, + num_processes=self.args.dataset_num_proc, ) def _get_train_sampler( diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 42488e643..2b53c6798 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -68,6 +68,10 @@ class AxolotlTrainingMixins: default=2048, metadata={"help": "The maximum sequence length the model can handle"}, ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "The number of processes to use for data processing"}, + ) relora_steps: Optional[int] = field( default=None, metadata={"help": "how often to reset for ReLoRA"}, diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index e488ed7d5..eabfc2d84 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -3,6 +3,7 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length into fixed-capacity batches to optimize memory usage and training throughput. """ +import gc import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -145,7 +146,7 @@ def pack_parallel( """ num_items = len(sequence_lengths) if num_processes is None: - num_processes = max(1, min(num_items // group_size, cpu_count())) + num_processes = max(1, min(num_items // group_size, cpu_count(), 16)) # Create tasks for parallel processing tasks = [] @@ -259,7 +260,7 @@ class MultipackBatchSampler(BatchSampler): lengths: np.ndarray, # Sequence lengths packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate drop_last: bool = False, # Whether to drop final batches (might be incomplete) - num_count_samples: int = 16, # Number of times to estimate batch count + num_count_samples: int = 8, # Number of times to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing bin_size: int = 200, # The max number of samples that can be packed in a single bin @@ -349,6 +350,7 @@ class MultipackBatchSampler(BatchSampler): # Calculate efficiency statistics total_used = lengths.sum() total_slots = len(all_bins) * self.batch_max_len + del all_bins # Group bins into batches (each batch contains batch_size bins) batches = [ @@ -368,6 +370,7 @@ class MultipackBatchSampler(BatchSampler): self.total_token_slots += total_slots self._batches = batches + gc.collect() return batches def __iter__(self) -> Iterator[list[list[int]]]: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 67f590a37..ec5360fa3 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): bin_size=cfg.sample_packing_bin_size, sequential=cfg.sample_packing_sequentially, drop_last=True, + num_processes=cfg.dataset_processes, ) data_loader = DataLoader(