support for explicit test_dataset definition for evals (#786)

This commit is contained in:
Wing Lian
2024-01-22 21:29:56 -05:00
committed by GitHub
parent e799e08d3c
commit cda52dc32b
2 changed files with 44 additions and 29 deletions

View File

@@ -519,6 +519,11 @@ def validate_config(cfg):
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention" "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
) )
if cfg.test_datasets and cfg.val_set_size:
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -4,7 +4,7 @@ import hashlib
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets( if cfg.test_datasets:
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH train_dataset, _, prompters = load_prepare_datasets(
) tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else: else:
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
name = None name = None
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[DatasetDict, List[Prompter]]: ) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( md5(
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
sorted( sorted(
[ [
f"{d.path}:{d.type}:{d.shards}:{d.conversation}" f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets for d in cfg_datasets
] ]
) )
) )
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token, token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
@@ -188,8 +200,8 @@ def load_tokenized_prepared_datasets(
yield dataset yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets): for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = None ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset( load_dataset(
@@ -342,16 +354,6 @@ def load_tokenized_prepared_datasets(
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data
if config_dataset.shards:
if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=config_dataset.shards, index=0
)
else:
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=0
)
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = config_dataset.type d_type = config_dataset.type
@@ -359,17 +361,21 @@ def load_tokenized_prepared_datasets(
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"] if config_dataset.split and config_dataset.split in ds:
elif ( ds = ds[config_dataset.split]
isinstance(ds, DatasetDict) elif split in ds:
and config_dataset.train_on_split ds = ds[split]
and config_dataset.train_on_split in ds
):
ds = ds[config_dataset.train_on_split]
elif isinstance(ds, DatasetDict): elif isinstance(ds, DatasetDict):
raise ValueError( raise ValueError(
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
# support for using a subset of the data
if config_dataset.shards:
shards_idx = config_dataset.get("shards_idx", 0)
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=shards_idx
) )
dataset_wrapper, dataset_prompter = get_dataset_wrapper( dataset_wrapper, dataset_prompter = get_dataset_wrapper(
@@ -428,6 +434,7 @@ def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]: ) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets( dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx, index=cfg.dataset_shard_idx,
) )
if cfg.val_set_size: if split == "train" and cfg.val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = ( to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access dataset._fingerprint # pylint: disable=protected-access
@@ -475,6 +482,9 @@ def load_prepare_datasets(
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]
elif split == "test":
train_dataset = None
eval_dataset = dataset
else: else:
train_dataset = dataset train_dataset = dataset
eval_dataset = None eval_dataset = None