support for explicit test_dataset definition for evals (#786)

commit cda52dc32b
parent e799e08d3c
Author: Wing Lian
Date: 2024-01-22 21:29:56 -05:00
Committed by: GitHub

2 changed files with 44 additions and 29 deletions


@@ -519,6 +519,11 @@ def validate_config(cfg):
             "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
         )
 
+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25

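The guard makes the two eval sources mutually exclusive: either carve an eval split out of the training data with val_set_size, or supply eval data explicitly via test_datasets, never both. A minimal sketch of the rule, as a hypothetical standalone function rather than axolotl's full validate_config (dataset path is made up):

    def check_eval_config(test_datasets, val_set_size):
        # mirrors the guard added above; cfg fields passed in explicitly
        if test_datasets and val_set_size:
            raise ValueError(
                "non-zero val_set_size should not be used with test_datasets configuration"
            )

    check_eval_config(test_datasets=None, val_set_size=0.05)  # fine: auto split
    check_eval_config(test_datasets=[{"path": "my-org/my-eval"}], val_set_size=0)  # fine: explicit eval
    # check_eval_config(test_datasets=[{"path": "my-org/my-eval"}], val_set_size=0.05)  # ValueError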

@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         path = cfg.pretraining_dataset
         name = None
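prepare_dataset now makes two passes when test_datasets is present: the first builds the train split from cfg.datasets, the second builds the eval split from cfg.test_datasets, with the unused return slots discarded each time. A hypothetical config that would take the new branch, written as a Python dict using the same keys a user would set in the YAML config (dataset paths are made up):

    cfg = {
        "datasets": [
            {"path": "tatsu-lab/alpaca", "type": "alpaca"},
        ],
        "test_datasets": [
            # loaded only on the "test" pass; `split` names the source split
            {"path": "my-org/my-eval-set", "type": "alpaca", "split": "test"},
        ],
        "val_set_size": 0,  # must be 0 when test_datasets is set (see guard above)
    }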
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
                 sorted(
                     [
                         f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
-                        for d in cfg.datasets
+                        for d in cfg_datasets
                     ]
                 )
             )
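Because the hashed descriptor list now comes from cfg_datasets, the train and test preparations land under distinct prepared-dataset cache keys. A rough illustration of that fingerprinting idea; the real hash also mixes in other fields (tokenizer name, sequence length, and so on), so treat this as a sketch only:

    import hashlib

    def dataset_cache_key(cfg_datasets):
        # sort so the key is order-independent, like the sorted() call above
        descriptor = "|".join(
            sorted(f"{d['path']}:{d.get('type')}:{d.get('shards')}" for d in cfg_datasets)
        )
        return hashlib.md5(descriptor.encode("utf-8")).hexdigest()

    train_key = dataset_cache_key([{"path": "tatsu-lab/alpaca", "type": "alpaca"}])
    test_key = dataset_cache_key([{"path": "my-org/my-eval", "type": "alpaca"}])
    assert train_key != test_key  # distinct dataset lists -> distinct prepared paths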
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass
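A prepared dataset previously pushed to the hub comes back as a DatasetDict keyed by split name, so indexing with the requested split replaces the hard-coded "train" lookup. For reference, the datasets-library behavior relied on here (repo name is illustrative):

    from datasets import load_dataset

    # Without a split argument, load_dataset returns a DatasetDict such as
    # DatasetDict({'train': Dataset(...), 'test': Dataset(...)}).
    dataset = load_dataset("my-org/prepared-dataset")  # illustrative repo name
    split = "test"
    dataset = dataset[split]  # raises KeyError if that split was never pushed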
@@ -188,8 +200,8 @@ def load_tokenized_prepared_datasets(
                 yield dataset
 
         # pylint: disable=invalid-name
-        for config_dataset in for_d_in_datasets(cfg.datasets):
-            ds: Union[Dataset, DatasetDict] = None
+        for config_dataset in for_d_in_datasets(cfg_datasets):
+            ds: Optional[Union[Dataset, DatasetDict]] = None
             ds_from_hub = False
             try:
                 load_dataset(
@@ -342,16 +354,6 @@ def load_tokenized_prepared_datasets(
                 )
             if not ds:
                 raise ValueError("unhandled dataset load")
-            # support for using a subset of the data
-            if config_dataset.shards:
-                if "train" in ds:
-                    ds = ds.shuffle(seed=seed)["train"].shard(
-                        num_shards=config_dataset.shards, index=0
-                    )
-                else:
-                    ds = ds.shuffle(seed=seed).shard(
-                        num_shards=config_dataset.shards, index=0
-                    )
 
             d_base_type = d_prompt_style = None
             d_type = config_dataset.type
@@ -359,17 +361,21 @@ def load_tokenized_prepared_datasets(
                 d_type_split = d_type.split(":")
                 d_base_type = d_type_split[0]
                 d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-            if "train" in ds:
-                ds = ds["train"]
-            elif (
-                isinstance(ds, DatasetDict)
-                and config_dataset.train_on_split
-                and config_dataset.train_on_split in ds
-            ):
-                ds = ds[config_dataset.train_on_split]
+
+            if config_dataset.split and config_dataset.split in ds:
+                ds = ds[config_dataset.split]
+            elif split in ds:
+                ds = ds[split]
             elif isinstance(ds, DatasetDict):
                 raise ValueError(
-                    f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
+                    f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
                 )
+
+            # support for using a subset of the data
+            if config_dataset.shards:
+                shards_idx = config_dataset.get("shards_idx", 0)
+                ds = ds.shuffle(seed=seed).shard(
+                    num_shards=config_dataset.shards, index=shards_idx
+                )
 
             dataset_wrapper, dataset_prompter = get_dataset_wrapper(
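Moving the shard logic to after split selection means it always operates on a flat Dataset rather than a DatasetDict, which is what lets the old train-vs-dict special case collapse into one branch; the new shards_idx key picks which shard survives. The underlying datasets calls, on toy data:

    from datasets import Dataset

    ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

    # e.g. `shards: 4` and `shards_idx: 1` in a dataset config entry
    shards, shards_idx = 4, 1
    subset = ds.shuffle(seed=42).shard(num_shards=shards, index=shards_idx)
    print(len(subset))  # 25: a deterministic quarter of the shuffled data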
@@ -428,6 +434,7 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
     dataset, prompters = load_tokenized_prepared_datasets(
         tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )
 
-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -475,6 +482,9 @@ def load_prepare_datasets(
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
     else:
         train_dataset = dataset
         eval_dataset = None
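Taken together, load_prepare_datasets now has three return shapes: (train, eval, prompters) when an automatic val split is carved out, (None, dataset, prompters) on a test pass, and (dataset, None, prompters) otherwise. A stub sketch of how prepare_dataset's two passes consume those shapes (placeholder values, not the real dataset objects):

    def load_prepare_datasets_stub(split="train", val_set_size=0):
        # placeholder stand-in for the real loader's three return shapes
        if split == "test":
            return None, "eval-data", []
        if val_set_size:
            return "train-subset", "val-subset", ["prompter"]
        return "train-data", None, ["prompter"]

    # the two-pass pattern from prepare_dataset, using the stub
    train_dataset, _, prompters = load_prepare_datasets_stub(split="train")
    _, eval_dataset, _ = load_prepare_datasets_stub(split="test")
    assert (train_dataset, eval_dataset) == ("train-data", "eval-data")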