Use fork as the multiprocessing start method for parallel sample packing (#2830)

This commit is contained in:
Wing Lian
2025-06-25 13:17:33 -04:00
committed by GitHub
parent 181cc3106b
commit d8cf66edbd
5 changed files with 17 additions and 2 deletions

View File

@@ -116,6 +116,7 @@ class AxolotlTrainer(
sequential=self.args.sample_packing_sequentially, sequential=self.args.sample_packing_sequentially,
drop_last=True, drop_last=True,
num_processes=self.args.dataset_num_proc, num_processes=self.args.dataset_num_proc,
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
) )
len(sampler) len(sampler)

View File

@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
}, },
) )
sample_packing_mp_start_method: str | None = field(
default=None,
metadata={"help": "The multiprocessing start method to use."},
)
multipack_real_batches: bool = field( multipack_real_batches: bool = field(
default=False, default=False,
metadata={"help": "Use real batches for efficient training."}, metadata={"help": "Use real batches for efficient training."},

View File

@@ -127,7 +127,7 @@ def pack_parallel(
bin_size: int, bin_size: int,
num_processes: int | None = None, num_processes: int | None = None,
safe_mode: bool = True, safe_mode: bool = True,
mp_start_method: str | None = "spawn", mp_start_method: str | None = "fork",
) -> list[list[int]]: ) -> list[list[int]]:
"""Pack sequences into bins using parallel processing. """Pack sequences into bins using parallel processing.
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler):
bin_size: int = 200, # The max number of samples that can be packed in a single bin bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
self.bin_size = bin_size self.bin_size = bin_size
self.num_processes = num_processes self.num_processes = num_processes
self.safe_mode = safe_mode self.safe_mode = safe_mode
self.mp_start_method = mp_start_method
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
@@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler):
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
group_size=self.group_size, group_size=self.group_size,
bin_size=self.bin_size, bin_size=self.bin_size,
num_processes=self.num_processes, num_processes=max(4, self.num_processes) if self.num_processes else 4,
safe_mode=self.safe_mode, safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,
) )
# Map bin indices back to original indices # Map bin indices back to original indices

View File

@@ -393,6 +393,12 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "Whether to pack samples sequentially"}, json_schema_extra={"description": "Whether to pack samples sequentially"},
) )
sample_packing_mp_start_method: str | None = Field(
default=None,
json_schema_extra={
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
},
)
eval_sample_packing: bool | None = Field( eval_sample_packing: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sequential=cfg.sample_packing_sequentially, sequential=cfg.sample_packing_sequentially,
drop_last=True, drop_last=True,
num_processes=cfg.dataset_processes, num_processes=cfg.dataset_processes,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
) )
data_loader = DataLoader( data_loader = DataLoader(