prevent rate limiting to hf when using dispatch batches (#2536) [skip ci]
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
|||||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||||
)
|
)
|
||||||
|
|
||||||
iter_ds = load_dataset(
|
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
|
||||||
path, streaming=True, split=split, name=name, data_files=data_files
|
# other ranks, we just need to present a fake dataset
|
||||||
)
|
if (
|
||||||
|
cfg.accelerator_config
|
||||||
|
and cfg.accelerator_config.dispatch_batches
|
||||||
|
and not is_local_main_process()
|
||||||
|
):
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
|
||||||
|
f.write("text\n")
|
||||||
|
f.write("lorem ipsum dolor sit amet\n")
|
||||||
|
# rewind the file pointer to the beginning so we can read it again
|
||||||
|
f.seek(0)
|
||||||
|
iter_ds = load_dataset(
|
||||||
|
"csv", data_files=f.name, split="train", streaming=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if is_local_main_process():
|
||||||
|
iter_ds = load_dataset(
|
||||||
|
path, streaming=True, split=split, name=name, data_files=data_files
|
||||||
|
)
|
||||||
|
|
||||||
if skip:
|
if skip:
|
||||||
LOG.info(f"Skipping {skip} samples from the dataset")
|
LOG.info(f"Skipping {skip} samples from the dataset")
|
||||||
iter_ds = iter_ds.skip(skip)
|
iter_ds = iter_ds.skip(skip)
|
||||||
|
|||||||
@@ -660,6 +660,7 @@ class AxolotlInputConfig(
|
|||||||
data.get("val_set_size") == 0
|
data.get("val_set_size") == 0
|
||||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||||
and not data.get("test_datasets")
|
and not data.get("test_datasets")
|
||||||
|
and data.get("eval_strategy") != "no"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||||
|
|||||||
Reference in New Issue
Block a user