use fork for multiprocess start method for packing in parallel (#2830)

This commit is contained in:
Wing Lian
2025-06-25 13:17:33 -04:00
committed by GitHub
parent 181cc3106b
commit d8cf66edbd
5 changed files with 17 additions and 2 deletions

View File

@@ -116,6 +116,7 @@ class AxolotlTrainer(
sequential=self.args.sample_packing_sequentially,
drop_last=True,
num_processes=self.args.dataset_num_proc,
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
)
len(sampler)

View File

@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
sample_packing_mp_start_method: str | None = field(
default=None,
metadata={"help": "The multiprocessing start method to use."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -127,7 +127,7 @@ def pack_parallel(
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
mp_start_method: str | None = "fork",
) -> list[list[int]]:
"""Pack sequences into bins using parallel processing.
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler):
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
**kwargs, # pylint: disable=unused-argument
):
super().__init__(sampler, batch_size, drop_last)
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
self.mp_start_method = mp_start_method
assert isinstance(self.lengths, np.ndarray)
@@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler):
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
num_processes=max(4, self.num_processes) if self.num_processes else 4,
safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,
)
# Map bin indices back to original indices

View File

@@ -393,6 +393,12 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "Whether to pack samples sequentially"},
)
sample_packing_mp_start_method: str | None = Field(
default=None,
json_schema_extra={
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
},
)
eval_sample_packing: bool | None = Field(
default=None,
json_schema_extra={

View File

@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)
data_loader = DataLoader(