From 08b8fa62cc27d0c8bd7b8cb9bba91d6fcf9067ac Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 9 Oct 2025 14:18:46 -0400 Subject: [PATCH] only calculate packed ds length once if using a large world size (#3210) --- src/axolotl/utils/samplers/multipack.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index d07988613..662c63caa 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput. import gc import math +import os import time from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler): self.total_token_slots = 0 # The number of times to calculate batches to determine minimum packed dataset length - self.num_count_samples = num_count_samples + world_size = int(os.environ.get("WORLD_SIZE", "1")) + self.num_count_samples = ( + 1 if world_size >= num_count_samples else num_count_samples + ) if self.sequential and not isinstance(sampler, SequentialSampler): LOG.warning(