Compare commits
1 Commits
no-zero-ds
...
axolotl-ci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9e5e22e6b |
@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_local_main_process
|
||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
process_datasets_for_packing,
|
||||
@@ -66,31 +66,32 @@ LOG = logging.getLogger(__name__)
|
||||
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
if cfg.test_datasets:
|
||||
train_dataset, _, prompters = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="train",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
with zero_first(is_local_main_process()):
|
||||
if cfg.test_datasets:
|
||||
train_dataset, _, prompters = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="train",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
# Load streaming dataset if pretraining_dataset is given
|
||||
path = cfg.pretraining_dataset
|
||||
@@ -271,7 +272,7 @@ def load_tokenized_prepared_datasets(
|
||||
LOG.info("Loading raw datasets...")
|
||||
if not cfg.is_preprocess:
|
||||
LOG.warning(
|
||||
"Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset."
|
||||
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
|
||||
)
|
||||
|
||||
if cfg.seed:
|
||||
|
||||
@@ -58,11 +58,15 @@ def snapshot_download_w_retry(*args, **kwargs):
|
||||
"""
|
||||
with hf_offline_context(True):
|
||||
try:
|
||||
return snapshot_download(*args, **kwargs)
|
||||
return snapshot_download(
|
||||
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
|
||||
)
|
||||
except LocalEntryNotFoundError:
|
||||
pass
|
||||
with hf_offline_context(False):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
return snapshot_download(
|
||||
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
|
||||
Reference in New Issue
Block a user