fix bin size (#3307)
* fix bin size * lint --------- Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
@@ -203,6 +203,7 @@ def wrap_streaming_dataset(
|
||||
max_seq_length=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
multipack_attn=multipack_attn,
|
||||
bin_size=cfg.sample_packing_bin_size,
|
||||
)
|
||||
|
||||
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
|
||||
@@ -254,6 +255,7 @@ def encode_packed_streaming(
|
||||
collate_fn,
|
||||
ds_wrapper: Callable,
|
||||
examples: Dict[str, List],
|
||||
bin_size: int,
|
||||
max_seq_length: int = 2048,
|
||||
batch_size: int = 4,
|
||||
multipack_attn: Optional[bool] = True,
|
||||
@@ -278,6 +280,7 @@ def encode_packed_streaming(
|
||||
batch_max_len=batch_size * max_seq_length,
|
||||
drop_last=True,
|
||||
num_processes=1,
|
||||
bin_size=bin_size,
|
||||
)
|
||||
|
||||
chunked_data = defaultdict(list)
|
||||
|
||||
@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
|
||||
batch_size: int, # Number of bins per batch
|
||||
batch_max_len: int, # Maximum sequence length (bin capacity)
|
||||
lengths: np.ndarray, # Sequence lengths
|
||||
bin_size: int, # The max number of samples that can be packed in a single bin
|
||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||
num_count_samples: int = 4, # Number of times to estimate batch count
|
||||
sequential: bool = False, # Whether to use sequential packing
|
||||
group_size: int = 100_000, # Size of groups for parallel packing
|
||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||
num_processes: int | None = None, # Number of processes for parallel packing
|
||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||
mp_start_method: str = "fork",
|
||||
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths,
|
||||
bin_capacity=self.batch_max_len,
|
||||
group_size=self.group_size,
|
||||
bin_size=self.bin_size,
|
||||
bin_size=self.bin_size or self.batch_max_len,
|
||||
num_processes=min(4, num_processes) if num_processes else 4,
|
||||
safe_mode=self.safe_mode,
|
||||
mp_start_method=self.mp_start_method,
|
||||
|
||||
Reference in New Issue
Block a user