adding test_datasets compat with pretraining_dataset (streaming) (#2206) [skip ci]
This commit is contained in:
@@ -85,6 +85,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Load streaming dataset if pretraining_dataset is given
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
split = "train"
|
split = "train"
|
||||||
name = None
|
name = None
|
||||||
@@ -116,7 +117,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
|
|
||||||
|
# Load eval dataset (non-streaming) if specified
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
if cfg.test_datasets:
|
||||||
|
_, eval_dataset, _ = load_prepare_datasets(
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
|
split="test",
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication:
|
||||||
LOG.info("Deduplication not available for pretrained datasets")
|
LOG.info("Deduplication not available for pretrained datasets")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user