only calculate packed ds length once if using a large world size (#3210)
This commit is contained in:
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from multiprocessing import cpu_count, get_context
|
from multiprocessing import cpu_count, get_context
|
||||||
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.total_token_slots = 0
|
self.total_token_slots = 0
|
||||||
|
|
||||||
# The number of times to calculate batches to determine minimum packed dataset length
|
# The number of times to calculate batches to determine minimum packed dataset length
|
||||||
self.num_count_samples = num_count_samples
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
|
self.num_count_samples = (
|
||||||
|
1 if world_size >= num_count_samples else num_count_samples
|
||||||
|
)
|
||||||
|
|
||||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user