limit multipack sampler processes (#2771) [skip ci]

* limit the default number of packing processes to at most 16

* make the sampler's num_processes properly reflect the configured dataset_processes
This commit is contained in:
Wing Lian
2025-06-12 13:22:58 -04:00
committed by GitHub
parent 3634d8ff9d
commit 468580d18e
6 changed files with 14 additions and 6 deletions

View File

@@ -490,6 +490,9 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -90,10 +90,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
training_args_kwargs["remove_unused_columns"] = False
# only rlhf
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:

View File

@@ -114,6 +114,7 @@ class AxolotlTrainer(
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
num_processes=self.args.dataset_num_proc,
)
def _get_train_sampler(

View File

@@ -68,6 +68,10 @@ class AxolotlTrainingMixins:
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "The number of processes to use for data processing"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},

View File

@@ -3,6 +3,7 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import gc
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -145,7 +146,7 @@ def pack_parallel(
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
num_processes = max(1, min(num_items // group_size, cpu_count(), 16))
# Create tasks for parallel processing
tasks = []
@@ -259,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
num_count_samples: int = 8, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
@@ -349,6 +350,7 @@ class MultipackBatchSampler(BatchSampler):
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
del all_bins
# Group bins into batches (each batch contains batch_size bins)
batches = [
@@ -368,6 +370,7 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots += total_slots
self._batches = batches
gc.collect()
return batches
def __iter__(self) -> Iterator[list[list[int]]]:

View File

@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
)
data_loader = DataLoader(