diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 3864903a5..ab9735adc 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -114,6 +114,8 @@ class AxolotlTrainer( packing_efficiency_estimate=self.args.sample_packing_efficiency, batch_max_len=batch_max_len, batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, sequential=self.args.sample_packing_sequentially, drop_last=True, ) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 1abe594e3..394470b4b 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -61,7 +61,8 @@ def pack_group( group_offset: int, bin_capacity: int, max_bins: int, - safe_mode: bool = False, + bin_size: int, + safe_mode: bool = True, ): """ Pack a group of sequences into bins using First-Fit Decreasing algorithm @@ -71,6 +72,7 @@ def pack_group( group_offset: Offset to apply to indices when returning results bin_capacity: Maximum capacity of each bin max_bins: Maximum number of bins to use + bin_size: Maximum number of sequences per bin safe_mode: If True, use a more conservative packing approach Returns: @@ -89,7 +91,10 @@ def pack_group( # Try to place sequence in existing bins add_new_bin = True for bin_idx, _ in enumerate(bins_remaining_space): - if bins_remaining_space[bin_idx] >= size: + if ( + bins_remaining_space[bin_idx] >= size + and len(bins_assigned_sequences[bin_idx]) < bin_size + ): bins_remaining_space[bin_idx] -= size bins_assigned_sequences[bin_idx].append(global_idx) add_new_bin = False @@ -112,14 +117,17 @@ def pack_group( # Define a standalone function for multiprocessing def _process_group(args): - group_lengths, start_idx, bin_capacity, max_bins, safe_mode = args - return pack_group(group_lengths, start_idx, bin_capacity, max_bins, safe_mode) + group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args + return pack_group( + group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode + ) def pack_parallel( sequence_lengths: np.ndarray, bin_capacity: int, group_size: int, + bin_size: int, num_processes: int | None = None, safe_mode: bool = True, ): @@ -128,8 +136,9 @@ def pack_parallel( Args: sequence_lengths: Array of sequence lengths - bin_capacity: Maximum capacity of each bin + bin_capacity: Maximum capacity of each bin as total number of tokens group_size: Number of sequences to process in each group + bin_size: Maximum number of bins to use num_processes: Number of parallel processes to use safe_mode: If True, use a more conservative packing approach @@ -145,7 +154,7 @@ def pack_parallel( for i in range(0, num_items, group_size): group_lengths = sequence_lengths[i : i + group_size] max_bins = len(group_lengths) # Allow as many bins as items in the group - tasks.append((group_lengths, i, bin_capacity, max_bins, safe_mode)) + tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode)) # Process groups in parallel all_bins = [] @@ -230,6 +239,7 @@ class MultipackBatchSampler(BatchSampler): num_count_samples: int = 16, # Number of samples to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing + bin_size: int = 200, # The max number of samples that can be packed in a single bin num_processes: int | None = None, # Number of processes for parallel packing safe_mode: bool = True, # Conservative packing to prevent training instability **kwargs, # pylint: disable=unused-argument @@ -241,6 +251,7 @@ class MultipackBatchSampler(BatchSampler): self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.sequential = sequential self.group_size = group_size + self.bin_size = bin_size self.num_processes = num_processes self.safe_mode = safe_mode @@ -261,7 +272,7 @@ class MultipackBatchSampler(BatchSampler): self._batches = None if self.sequential and not isinstance(sampler, SequentialSampler): - LOG.warn( + LOG.warning( "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?" ) @@ -306,6 +317,7 @@ class MultipackBatchSampler(BatchSampler): lengths, bin_capacity=self.batch_max_len, group_size=self.group_size, + bin_size=self.bin_size, num_processes=self.num_processes, safe_mode=self.safe_mode, )