Use fork as the multiprocessing start method for parallel sample packing (#2830)
This commit is contained in:
@@ -116,6 +116,7 @@ class AxolotlTrainer(
|
|||||||
sequential=self.args.sample_packing_sequentially,
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=self.args.dataset_num_proc,
|
num_processes=self.args.dataset_num_proc,
|
||||||
|
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
len(sampler)
|
len(sampler)
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The multiprocessing start method to use."},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def pack_parallel(
|
|||||||
bin_size: int,
|
bin_size: int,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
safe_mode: bool = True,
|
safe_mode: bool = True,
|
||||||
mp_start_method: str | None = "spawn",
|
mp_start_method: str | None = "fork",
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
"""Pack sequences into bins using parallel processing.
|
"""Pack sequences into bins using parallel processing.
|
||||||
|
|
||||||
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||||
num_processes: int | None = None, # Number of processes for parallel packing
|
num_processes: int | None = None, # Number of processes for parallel packing
|
||||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||||
|
mp_start_method: str = "fork",
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.bin_size = bin_size
|
self.bin_size = bin_size
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
self.safe_mode = safe_mode
|
self.safe_mode = safe_mode
|
||||||
|
self.mp_start_method = mp_start_method
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bin_capacity=self.batch_max_len,
|
bin_capacity=self.batch_max_len,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bin_size=self.bin_size,
|
bin_size=self.bin_size,
|
||||||
num_processes=self.num_processes,
|
num_processes=max(4, self.num_processes) if self.num_processes else 4,
|
||||||
safe_mode=self.safe_mode,
|
safe_mode=self.safe_mode,
|
||||||
|
mp_start_method=self.mp_start_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map bin indices back to original indices
|
# Map bin indices back to original indices
|
||||||
|
|||||||
@@ -393,6 +393,12 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: bool | None = Field(
|
eval_sample_packing: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sequential=cfg.sample_packing_sequentially,
|
sequential=cfg.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=cfg.dataset_processes,
|
num_processes=cfg.dataset_processes,
|
||||||
|
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
|
|||||||
Reference in New Issue
Block a user