Only calculate the packed dataset length once when using a large world size (#3210)

This commit is contained in:
Wing Lian
2025-10-09 14:18:46 -04:00
committed by GitHub
parent 3a5c97e6e5
commit 08b8fa62cc

View File

@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
 import gc
 import math
+import os
 import time
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count, get_context
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
 self.total_token_slots = 0
 # The number of times to calculate batches to determine minimum packed dataset length
-self.num_count_samples = num_count_samples
+world_size = int(os.environ.get("WORLD_SIZE", "1"))
+self.num_count_samples = (
+    1 if world_size >= num_count_samples else num_count_samples
+)
 if self.sequential and not isinstance(sampler, SequentialSampler):
     LOG.warning(