limit multipack sampler processes (#2771) [skip ci]
* Limit multipack packing to at most 16 processes.
* Make `num_processes` properly reflect the configured `dataset_processes`.
This commit is contained in:
@@ -490,6 +490,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
@@ -90,10 +90,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
# only rlhf
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
||||
elif self.cfg.rl_beta is not None:
|
||||
|
||||
@@ -114,6 +114,7 @@ class AxolotlTrainer(
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=self.args.dataset_num_proc,
|
||||
)
|
||||
|
||||
def _get_train_sampler(
|
||||
|
||||
@@ -68,6 +68,10 @@ class AxolotlTrainingMixins:
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for data processing"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
|
||||
@@ -3,6 +3,7 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
|
||||
into fixed-capacity batches to optimize memory usage and training throughput.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
@@ -145,7 +146,7 @@ def pack_parallel(
|
||||
"""
|
||||
num_items = len(sequence_lengths)
|
||||
if num_processes is None:
|
||||
num_processes = max(1, min(num_items // group_size, cpu_count()))
|
||||
num_processes = max(1, min(num_items // group_size, cpu_count(), 16))
|
||||
|
||||
# Create tasks for parallel processing
|
||||
tasks = []
|
||||
@@ -259,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths: np.ndarray, # Sequence lengths
|
||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
|
||||
num_count_samples: int = 16, # Number of times to estimate batch count
|
||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
||||
sequential: bool = False, # Whether to use sequential packing
|
||||
group_size: int = 100_000, # Size of groups for parallel packing
|
||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||
@@ -349,6 +350,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# Calculate efficiency statistics
|
||||
total_used = lengths.sum()
|
||||
total_slots = len(all_bins) * self.batch_max_len
|
||||
del all_bins
|
||||
|
||||
# Group bins into batches (each batch contains batch_size bins)
|
||||
batches = [
|
||||
@@ -368,6 +370,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.total_token_slots += total_slots
|
||||
|
||||
self._batches = batches
|
||||
gc.collect()
|
||||
return batches
|
||||
|
||||
def __iter__(self) -> Iterator[list[list[int]]]:
|
||||
|
||||
@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
bin_size=cfg.sample_packing_bin_size,
|
||||
sequential=cfg.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=cfg.dataset_processes,
|
||||
)
|
||||
|
||||
data_loader = DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user