use fork for multiprocess start method for packing in parallel (#2830)
This commit is contained in:
@@ -116,6 +116,7 @@ class AxolotlTrainer(
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=self.args.dataset_num_proc,
|
||||
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||
)
|
||||
|
||||
len(sampler)
|
||||
|
||||
@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
sample_packing_mp_start_method: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The multiprocessing start method to use."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
|
||||
@@ -127,7 +127,7 @@ def pack_parallel(
|
||||
bin_size: int,
|
||||
num_processes: int | None = None,
|
||||
safe_mode: bool = True,
|
||||
mp_start_method: str | None = "spawn",
|
||||
mp_start_method: str | None = "fork",
|
||||
) -> list[list[int]]:
|
||||
"""Pack sequences into bins using parallel processing.
|
||||
|
||||
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||
num_processes: int | None = None, # Number of processes for parallel packing
|
||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||
mp_start_method: str = "fork",
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.bin_size = bin_size
|
||||
self.num_processes = num_processes
|
||||
self.safe_mode = safe_mode
|
||||
self.mp_start_method = mp_start_method
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler):
|
||||
bin_capacity=self.batch_max_len,
|
||||
group_size=self.group_size,
|
||||
bin_size=self.bin_size,
|
||||
num_processes=self.num_processes,
|
||||
num_processes=max(4, self.num_processes) if self.num_processes else 4,
|
||||
safe_mode=self.safe_mode,
|
||||
mp_start_method=self.mp_start_method,
|
||||
)
|
||||
|
||||
# Map bin indices back to original indices
|
||||
|
||||
@@ -393,6 +393,12 @@ class AxolotlInputConfig(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
||||
)
|
||||
sample_packing_mp_start_method: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||
},
|
||||
)
|
||||
eval_sample_packing: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
sequential=cfg.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=cfg.dataset_processes,
|
||||
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||
)
|
||||
|
||||
data_loader = DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user