From d8cf66edbd33239bb93cd020dcba1e45ff4073be Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 25 Jun 2025 13:17:33 -0400 Subject: [PATCH] use fork for multiprocess start method for packing in parallel (#2830) --- src/axolotl/core/trainers/base.py | 1 + src/axolotl/core/training_args_base.py | 4 ++++ src/axolotl/utils/samplers/multipack.py | 7 +++++-- src/axolotl/utils/schemas/config.py | 6 ++++++ src/axolotl/utils/trainer.py | 1 + 5 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 3e9ea7ae8..b0e6e8eae 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -116,6 +116,7 @@ class AxolotlTrainer( sequential=self.args.sample_packing_sequentially, drop_last=True, num_processes=self.args.dataset_num_proc, + mp_start_method=self.args.sample_packing_mp_start_method or "fork", ) len(sampler) diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py index 8fcaff632..e04be43e0 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -38,6 +38,10 @@ class AxolotlTrainingMixins: "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." }, ) + sample_packing_mp_start_method: str | None = field( + default=None, + metadata={"help": "The multiprocessing start method to use."}, + ) multipack_real_batches: bool = field( default=False, metadata={"help": "Use real batches for efficient training."}, diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 7fb5e1b41..95d97e7a0 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -127,7 +127,7 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, - mp_start_method: str | None = "spawn", + mp_start_method: str | None = "fork", ) -> list[list[int]]: """Pack sequences into bins using parallel processing. @@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler): bin_size: int = 200, # The max number of samples that can be packed in a single bin num_processes: int | None = None, # Number of processes for parallel packing safe_mode: bool = True, # Conservative packing to prevent training instability + mp_start_method: str = "fork", **kwargs, # pylint: disable=unused-argument ): super().__init__(sampler, batch_size, drop_last) @@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler): self.bin_size = bin_size self.num_processes = num_processes self.safe_mode = safe_mode + self.mp_start_method = mp_start_method assert isinstance(self.lengths, np.ndarray) @@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler): bin_capacity=self.batch_max_len, group_size=self.group_size, bin_size=self.bin_size, - num_processes=self.num_processes, + num_processes=max(4, self.num_processes) if self.num_processes else 4, safe_mode=self.safe_mode, + mp_start_method=self.mp_start_method, ) # Map bin indices back to original indices diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c698fc3b6..4031742cd 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -393,6 +393,12 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Whether to pack samples sequentially"}, ) + sample_packing_mp_start_method: str | None = Field( + default=None, + json_schema_extra={ + "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'" + }, + ) eval_sample_packing: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 633dffde5..554a55abc 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): sequential=cfg.sample_packing_sequentially, drop_last=True, num_processes=cfg.dataset_processes, + mp_start_method=cfg.sample_packing_mp_start_method or "fork", ) data_loader = DataLoader(