Only calculate the packed dataset length once when using a large world size (#3210)
This commit is contained in:
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
|
||||
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.total_token_slots = 0
|
||||
|
||||
# The number of times to calculate batches to determine minimum packed dataset length
|
||||
self.num_count_samples = num_count_samples
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
self.num_count_samples = (
|
||||
1 if world_size >= num_count_samples else num_count_samples
|
||||
)
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warning(
|
||||
|
||||
Reference in New Issue
Block a user