Only calculate the packed dataset length once when using a large world size (#3210)

This commit is contained in:
Wing Lian
2025-10-09 14:18:46 -04:00
committed by GitHub
parent 3a5c97e6e5
commit 08b8fa62cc

View File

@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc
import math
import os
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.num_count_samples = (
1 if world_size >= num_count_samples else num_count_samples
)
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(