fix bin size (#3307)

* fix bin size

* lint

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
VED
2025-12-08 19:46:18 +05:30
committed by GitHub
parent 75b20fb66f
commit b3f4aa149f
2 changed files with 5 additions and 2 deletions

View File

@@ -203,6 +203,7 @@ def wrap_streaming_dataset(
max_seq_length=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
multipack_attn=multipack_attn,
bin_size=cfg.sample_packing_bin_size,
)
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
@@ -254,6 +255,7 @@ def encode_packed_streaming(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
bin_size: int,
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
@@ -278,6 +280,7 @@ def encode_packed_streaming(
batch_max_len=batch_size * max_seq_length,
drop_last=True,
num_processes=1,
bin_size=bin_size,
)
chunked_data = defaultdict(list)

View File

@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
bin_size: int, # The max number of samples that can be packed in a single bin
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 4, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
bin_size=self.bin_size or self.batch_max_len,
num_processes=min(4, num_processes) if num_processes else 4,
safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,