From b3f4aa149fa1d2a812c728b777e87822420ecde5 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Mon, 8 Dec 2025 19:46:18 +0530
Subject: [PATCH] fix bin size (#3307)

* fix bin size

* lint

---------

Co-authored-by: Ved
---
 src/axolotl/utils/data/streaming.py     | 3 +++
 src/axolotl/utils/samplers/multipack.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py
index 2cb35ee7c..8b6b8a439 100644
--- a/src/axolotl/utils/data/streaming.py
+++ b/src/axolotl/utils/data/streaming.py
@@ -203,6 +203,7 @@ def wrap_streaming_dataset(
         max_seq_length=cfg.sequence_len,
         batch_size=cfg.micro_batch_size,
         multipack_attn=multipack_attn,
+        bin_size=cfg.sample_packing_bin_size,
     )
 
     # Set this to 1 so downstream data_loader doesn't try to increase the batch size
@@ -254,6 +255,7 @@ def encode_packed_streaming(
     collate_fn,
     ds_wrapper: Callable,
     examples: Dict[str, List],
+    bin_size: int,
     max_seq_length: int = 2048,
     batch_size: int = 4,
     multipack_attn: Optional[bool] = True,
@@ -278,6 +280,7 @@ def encode_packed_streaming(
         batch_max_len=batch_size * max_seq_length,
         drop_last=True,
         num_processes=1,
+        bin_size=bin_size,
     )
 
     chunked_data = defaultdict(list)
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 662c63caa..436a49c79 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
         batch_size: int,  # Number of bins per batch
         batch_max_len: int,  # Maximum sequence length (bin capacity)
         lengths: np.ndarray,  # Sequence lengths
+        bin_size: int,  # The max number of samples that can be packed in a single bin
         packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
         drop_last: bool = True,  # Whether to drop final batches (might be incomplete)
         num_count_samples: int = 4,  # Number of times to estimate batch count
         sequential: bool = False,  # Whether to use sequential packing
         group_size: int = 100_000,  # Size of groups for parallel packing
-        bin_size: int = 200,  # The max number of samples that can be packed in a single bin
         num_processes: int | None = None,  # Number of processes for parallel packing
         safe_mode: bool = True,  # Conservative packing to prevent training instability
         mp_start_method: str = "fork",
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
             lengths,
             bin_capacity=self.batch_max_len,
             group_size=self.group_size,
-            bin_size=self.bin_size,
+            bin_size=self.bin_size or self.batch_max_len,
             num_processes=min(4, num_processes) if num_processes else 4,
             safe_mode=self.safe_mode,
             mp_start_method=self.mp_start_method,