From 341e95aac9d95a06dbc821250e6081e771e5d649 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 21 Apr 2025 10:31:35 -0400
Subject: [PATCH] prevent rate limiting to hf when using dispatch batches
 (#2536) [skip ci]

---
 src/axolotl/utils/data/sft.py       | 25 ++++++++++++++++++++++---
 src/axolotl/utils/schemas/config.py |  1 +
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 35eab30c5..413f6d144 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -3,6 +3,7 @@
 import functools
 import logging
 import os
+import tempfile
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
-        iter_ds = load_dataset(
-            path, streaming=True, split=split, name=name, data_files=data_files
-        )
+        # when letting accelerator dispatch batches from the main process, we don't need to load the dataset on
+        # the other ranks; we just need to present a small fake dataset
+        if (
+            cfg.accelerator_config
+            and cfg.accelerator_config.dispatch_batches
+            and not is_local_main_process()
+        ):
+            # delete=False keeps the temp file on disk; the streaming dataset reads it lazily after the handle closes
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
+                f.write("text\n")
+                f.write("lorem ipsum dolor sit amet\n")
+                # seek() flushes the buffered writes so the file contents are on disk for load_dataset
+                f.seek(0)
+                iter_ds = load_dataset(
+                    "csv", data_files=f.name, split="train", streaming=True
+                )
+        else:
+            iter_ds = load_dataset(
+                path, streaming=True, split=split, name=name, data_files=data_files
+            )
+
         if skip:
             LOG.info(f"Skipping {skip} samples from the dataset")
             iter_ds = iter_ds.skip(skip)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index a0fc2c7d3..732ae60cf 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -660,6 +660,7 @@ class AxolotlInputConfig(
             data.get("val_set_size") == 0
             and (data.get("eval_steps") or data.get("eval_strategy"))
             and not data.get("test_datasets")
+            and data.get("eval_strategy") != "no"
         ):
             raise ValueError(
                 "eval_steps and eval_strategy are not supported with val_set_size == 0"
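
Note (not part of the patch itself): with accelerate's dispatch_batches enabled, only the main
process iterates the real dataloader and then dispatches slices of each batch to the other ranks,
so non-main ranks never need the real dataset; before this change every rank opened the streaming
dataset against the HF Hub, which could trip rate limits. A minimal standalone sketch of the same
trick, using a hypothetical load_pretraining_stream helper and made-up arguments rather than
axolotl's actual prepare_dataset signature:

import tempfile

from datasets import load_dataset


def load_pretraining_stream(path, name=None, split="train", data_files=None,
                            dispatch_batches=False, is_main_process=True):
    """Hypothetical helper mirroring the patched logic above."""
    if dispatch_batches and not is_main_process:
        # non-main ranks never touch the HF Hub; a tiny local CSV with a "text"
        # column is enough, since the real batches come from the main process
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".csv", delete=False) as f:
            f.write("text\n")
            f.write("lorem ipsum dolor sit amet\n")
        return load_dataset("csv", data_files=f.name, split="train", streaming=True)
    # main process (or no batch dispatching): stream the real dataset
    return load_dataset(path, name=name, split=split, data_files=data_files, streaming=True)

delete=False is what keeps the temp file around for the lazy streaming reads; the file handle
itself closes as soon as the with block exits.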
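
The config.py hunk relaxes the validator so that an explicit eval_strategy of "no" no longer
triggers the val_set_size error. A small sketch of the condition as it reads after this patch,
with assumed example config dicts:

def raises_val_set_size_error(data: dict) -> bool:
    # mirrors the validator condition after the patch
    return bool(
        data.get("val_set_size") == 0
        and (data.get("eval_steps") or data.get("eval_strategy"))
        and not data.get("test_datasets")
        and data.get("eval_strategy") != "no"
    )


# eval requested with no val split and no test_datasets -> still an error
assert raises_val_set_size_error({"val_set_size": 0, "eval_steps": 100})
# eval explicitly disabled -> no longer an error after this patch
assert not raises_val_set_size_error({"val_set_size": 0, "eval_strategy": "no"})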