From 9f1d548534b933d2abc66369ce87a572958c954a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 May 2025 10:38:32 -0400 Subject: [PATCH] don't use zero first context for loading datasets --- src/axolotl/utils/data/sft.py | 55 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 6de2d2cf7..6f5931ea7 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -53,7 +53,7 @@ from axolotl.utils.data.utils import ( retry_on_request_exceptions, ) from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.distributed import is_local_main_process from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -66,32 +66,31 @@ LOG = logging.getLogger(__name__) def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): prompters = [] if not cfg.pretraining_dataset: - with zero_first(is_local_main_process()): - if cfg.test_datasets: - train_dataset, _, prompters = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="train", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - _, eval_dataset, _ = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="test", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - else: - train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - processor=processor, - preprocess_iterable=preprocess_iterable, - ) + if cfg.test_datasets: + train_dataset, _, prompters = load_prepare_datasets( + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="train", + processor=processor, + preprocess_iterable=preprocess_iterable, + ) + _, eval_dataset, _ = load_prepare_datasets( + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="test", + processor=processor, + preprocess_iterable=preprocess_iterable, + ) + else: + train_dataset, eval_dataset, prompters = load_prepare_datasets( + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + processor=processor, + preprocess_iterable=preprocess_iterable, + ) else: # Load streaming dataset if pretraining_dataset is given path = cfg.pretraining_dataset @@ -272,7 +271,7 @@ def load_tokenized_prepared_datasets( LOG.info("Loading raw datasets...") if not cfg.is_preprocess: LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." + "Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset." ) if cfg.seed: