prevent rate limiting to hf when using dispatch batches (#2536) [skip ci]

This commit is contained in:
Wing Lian
2025-04-21 10:31:35 -04:00
committed by GitHub
parent b882dfb63f
commit 341e95aac9
2 changed files with 23 additions and 3 deletions

View File

@@ -3,6 +3,7 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)

View File

@@ -660,6 +660,7 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"