support for explicit test_dataset definition for evals (#786)
This commit is contained in:
@@ -519,6 +519,11 @@ def validate_config(cfg):
|
|||||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.test_datasets and cfg.val_set_size:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
if cfg.test_datasets:
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
train_dataset, _, prompters = load_prepare_datasets(
|
||||||
)
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
||||||
|
)
|
||||||
|
_, eval_dataset, _ = load_prepare_datasets(
|
||||||
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
name = None
|
name = None
|
||||||
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
default_dataset_prepared_path,
|
||||||
|
split="train",
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
|
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
sorted(
|
sorted(
|
||||||
[
|
[
|
||||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||||
for d in cfg.datasets
|
for d in cfg_datasets
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset[split]
|
||||||
except Exception: # pylint: disable=broad-except # nosec
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -188,8 +200,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for config_dataset in for_d_in_datasets(cfg.datasets):
|
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
load_dataset(
|
load_dataset(
|
||||||
@@ -342,16 +354,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise ValueError("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
|
||||||
if config_dataset.shards:
|
|
||||||
if "train" in ds:
|
|
||||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
|
||||||
num_shards=config_dataset.shards, index=0
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ds = ds.shuffle(seed=seed).shard(
|
|
||||||
num_shards=config_dataset.shards, index=0
|
|
||||||
)
|
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
d_type = config_dataset.type
|
d_type = config_dataset.type
|
||||||
@@ -359,17 +361,21 @@ def load_tokenized_prepared_datasets(
|
|||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||||
if "train" in ds:
|
|
||||||
ds = ds["train"]
|
if config_dataset.split and config_dataset.split in ds:
|
||||||
elif (
|
ds = ds[config_dataset.split]
|
||||||
isinstance(ds, DatasetDict)
|
elif split in ds:
|
||||||
and config_dataset.train_on_split
|
ds = ds[split]
|
||||||
and config_dataset.train_on_split in ds
|
|
||||||
):
|
|
||||||
ds = ds[config_dataset.train_on_split]
|
|
||||||
elif isinstance(ds, DatasetDict):
|
elif isinstance(ds, DatasetDict):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
|
||||||
|
)
|
||||||
|
|
||||||
|
# support for using a subset of the data
|
||||||
|
if config_dataset.shards:
|
||||||
|
shards_idx = config_dataset.get("shards_idx", 0)
|
||||||
|
ds = ds.shuffle(seed=seed).shard(
|
||||||
|
num_shards=config_dataset.shards, index=shards_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||||
@@ -428,6 +434,7 @@ def load_prepare_datasets(
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
|
split="train",
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
@@ -442,7 +449,7 @@ def load_prepare_datasets(
|
|||||||
index=cfg.dataset_shard_idx,
|
index=cfg.dataset_shard_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.val_set_size:
|
if split == "train" and cfg.val_set_size:
|
||||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||||
to_hash_train = (
|
to_hash_train = (
|
||||||
dataset._fingerprint # pylint: disable=protected-access
|
dataset._fingerprint # pylint: disable=protected-access
|
||||||
@@ -475,6 +482,9 @@ def load_prepare_datasets(
|
|||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
elif split == "test":
|
||||||
|
train_dataset = None
|
||||||
|
eval_dataset = dataset
|
||||||
else:
|
else:
|
||||||
train_dataset = dataset
|
train_dataset = dataset
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|||||||
Reference in New Issue
Block a user