225 lines
7.8 KiB
Python
225 lines
7.8 KiB
Python
"""
|
|
Shared utilities for loading datasets.
|
|
"""
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub.errors import HFValidationError
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
def get_ds_type(config_dataset: DictDefault):
    """
    Get the ``datasets`` builder name to use for a dataset config.

    Resolution order:

    1. An explicit ``config_dataset.ds_type`` is returned unchanged.
    2. Otherwise, the first known extension (``.parquet``, ``.arrow``, ``.csv``,
       ``.txt``, checked in that order) found in ``config_dataset.path`` decides.
    3. Otherwise, the same check is applied to ``config_dataset.data_files``
       (a string, a list of strings, or a dict whose values are either). This
       way, a directory path such as ``data/`` with ``data_files: train.parquet``
       is not loaded with the JSON builder.
    4. Falls back to ``"json"``.
    """
    if config_dataset.ds_type:
        return config_dataset.ds_type

    # Substring rather than suffix matching: a URL such as
    # ``https://host/file.parquet?download=1`` still matches ``.parquet``.
    extension_to_type = (
        (".parquet", "parquet"),
        (".arrow", "arrow"),
        (".csv", "csv"),
        (".txt", "text"),
    )

    def _match(names):
        # Extension priority wins over the order of the names themselves.
        for extension, builder in extension_to_type:
            if any(extension in name for name in names):
                return builder
        return None

    def _flatten(value):
        # Collect every string file name from a str / list / tuple / dict value.
        # Anything else (e.g. None) contributes no names.
        if isinstance(value, str):
            return [value]
        if isinstance(value, dict):
            return [name for item in value.values() for name in _flatten(item)]
        if isinstance(value, (list, tuple)):
            return [name for item in value for name in _flatten(item)]
        return []

    return (
        _match(_flatten(config_dataset.path))
        or _match(_flatten(config_dataset.data_files))
        or "json"
    )
|
|
|
|
|
|
def _hub_dataset_is_loadable(config_dataset, auth_token) -> bool:
    """
    Return True if ``load_dataset`` can open ``config_dataset.path``.

    The dataset is opened with ``streaming=True`` and the result is discarded.
    The errors listed below mean "not loadable this way" and yield False.
    Any other exception propagates.
    """
    try:
        load_dataset(
            config_dataset.path,
            name=config_dataset.name,
            streaming=True,
            token=auth_token,
            revision=config_dataset.revision,
            trust_remote_code=config_dataset.trust_remote_code,
        )
    except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
        return False
    return True


def _get_remote_filesystem(path):
    """
    Build an fsspec filesystem for ``s3://``, ``gs://`` or ``gcs://`` paths.

    Returns:
        ``(storage_options, filesystem)``. For any other scheme this is
        ``({}, None)``.

    Raises:
        ImportError: if the optional backend package for the scheme is missing.
    """
    if path.startswith("s3://"):
        try:
            import aiobotocore.session  # type: ignore
            import s3fs  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "s3:// paths require aiobotocore and s3fs to be installed"
            ) from exc

        # Takes credentials from ~/.aws/credentials for default profile
        s3_session = aiobotocore.session.AioSession(profile="default")
        storage_options = {"session": s3_session}
        return storage_options, s3fs.S3FileSystem(**storage_options)

    if path.startswith(("gs://", "gcs://")):
        try:
            import gcsfs  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "gs:// or gcs:// paths require gcsfs to be installed"
            ) from exc

        # gcsfs will use default credentials from the environment else anon
        # https://gcsfs.readthedocs.io/en/latest/#credentials
        storage_options = {"token": None}
        return storage_options, gcsfs.GCSFileSystem(**storage_options)

    # TODO: support adl:// and abfs:// via adlfs.AzureBlobFileSystem once auth
    # credentials can be passed in (Gen 1: tenant_id / client_id /
    # client_secret; Gen 2: account_name / account_key).
    return {}, None


def _load_local_dataset(config_dataset, local_path, streaming):
    """
    Load a dataset from ``local_path``, which must already exist locally.

    Raises:
        ValueError: if the path is neither a directory nor a regular file.
    """
    if local_path.is_dir():
        if config_dataset.data_files:
            return load_dataset(
                get_ds_type(config_dataset),
                name=config_dataset.name,
                data_files=config_dataset.data_files,
                streaming=streaming,
                split=None,
            )
        try:
            # Directory produced by ``save_to_disk``.
            return load_from_disk(config_dataset.path)
        except FileNotFoundError:
            # Not a saved-dataset directory: let ``load_dataset`` resolve it.
            return load_dataset(
                config_dataset.path,
                name=config_dataset.name,
                streaming=False,
                split=None,
            )

    if local_path.is_file():
        # NOTE(review): single local files ignore the ``streaming`` argument,
        # unlike the other ``load_dataset`` calls. Confirm this is intentional.
        return load_dataset(
            get_ds_type(config_dataset),
            name=config_dataset.name,
            data_files=config_dataset.path,
            streaming=False,
            split=None,
        )

    raise ValueError(
        "unhandled dataset load: local path exists, but is neither a directory or a file"
    )


def _download_hub_data_files(config_dataset, auth_token):
    """
    Download ``config_dataset.data_files`` from the hub dataset repo ``path``.

    Returns:
        A local file path when ``data_files`` is a str, or a list of local
        paths when it is a list.

    Raises:
        ValueError: if ``data_files`` is neither a str nor a list.
    """
    data_files = config_dataset.data_files

    def _fetch(filename):
        return hf_hub_download(
            repo_id=config_dataset.path,
            repo_type="dataset",
            filename=filename,
            revision=config_dataset.revision,
            # Same credentials as the other hub calls, so private repos work.
            token=auth_token,
        )

    if isinstance(data_files, str):
        return _fetch(data_files)
    if isinstance(data_files, list):
        return [_fetch(file) for file in data_files]
    raise ValueError("data_files must be either a string or list of strings")


def load_dataset_w_config(
    config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
    """
    Load the dataset described by ``config_dataset``.

    The source is chosen in this order:

    1. A path that exists locally. Local paths are preferred even if a
       hub dataset with the same name exists. Directories go through
       ``data_files``, ``load_from_disk``, or ``load_dataset(path)``; single
       files go through the builder from ``get_ds_type``.
    2. A hub/builder dataset that ``load_dataset`` can open (probed up front).
    3. An existing ``s3://``, ``gs://`` or ``gcs://`` object. Directories are
       loaded with ``load_from_disk``; files with the ``get_ds_type`` builder.
    4. An ``https://`` URL, loaded with the ``get_ds_type`` builder.
    5. Otherwise, ``data_files`` are downloaded from the hub dataset repo
       ``config_dataset.path`` and loaded as JSON.

    Args:
        config_dataset: dataset config. The attributes read are ``path``,
            ``name``, ``revision``, ``data_files``, ``split`` and
            ``trust_remote_code`` (plus ``ds_type`` via ``get_ds_type``).
        auth_token: forwarded as ``token`` to hub calls.
        streaming: forwarded to ``load_dataset``. It is not applied to local
            single files, to local directories without ``data_files``, or to
            cloud directories.

    Returns:
        The loaded dataset. An empty dataset is returned as-is rather than
        being treated as a failed load. NOTE(review): with ``streaming=True``
        the result is whatever ``load_dataset`` returns in streaming mode,
        which the return annotation may not cover.

    Raises:
        ValueError: if no branch produced a dataset, if a local path is
            neither a file nor a directory, or if the hub-download fallback
            gets ``data_files`` that is not a str or list.
        ImportError: if a cloud path's optional filesystem package is missing.
    """
    trust_remote_code = config_dataset.trust_remote_code
    ds_from_hub = _hub_dataset_is_loadable(config_dataset, auth_token)

    storage_options, remote_file_system = _get_remote_filesystem(config_dataset.path)
    ds_from_cloud = False
    try:
        # Best-effort: an error here just means "not treated as a cloud dataset".
        if remote_file_system is not None and remote_file_system.exists(
            config_dataset.path
        ):
            ds_from_cloud = True
    except (FileNotFoundError, ConnectionError):
        pass

    ds: Optional[Union[Dataset, DatasetDict]] = None  # pylint: disable=invalid-name

    # prefer local dataset, even if hub exists
    local_path = Path(config_dataset.path)
    if local_path.exists():
        ds = _load_local_dataset(config_dataset, local_path, streaming)
    elif ds_from_hub:
        load_ds_kwargs = {}
        if config_dataset.split:
            load_ds_kwargs["split"] = config_dataset.split
        ds = load_dataset(
            config_dataset.path,
            name=config_dataset.name,
            streaming=streaming,
            data_files=config_dataset.data_files,
            token=auth_token,
            revision=config_dataset.revision,
            trust_remote_code=trust_remote_code,
            **load_ds_kwargs,
        )
    elif ds_from_cloud:
        if remote_file_system.isdir(config_dataset.path):
            ds = load_from_disk(
                config_dataset.path,
                storage_options=storage_options,
            )
        elif remote_file_system.isfile(config_dataset.path):
            ds = load_dataset(
                get_ds_type(config_dataset),
                name=config_dataset.name,
                data_files=config_dataset.path,
                streaming=streaming,
                split=None,
                storage_options=storage_options,
                trust_remote_code=trust_remote_code,
            )
    elif config_dataset.path.startswith("https://"):
        ds = load_dataset(
            get_ds_type(config_dataset),
            name=config_dataset.name,
            data_files=config_dataset.path,
            streaming=streaming,
            split=None,
            storage_options=storage_options,
            trust_remote_code=trust_remote_code,
        )
    else:
        fp = _download_hub_data_files(config_dataset, auth_token)
        ds = load_dataset(
            "json",
            name=config_dataset.name,
            data_files=fp,
            streaming=streaming,
            split=None,
        )

    # ``is None`` rather than falsiness: ``Dataset`` defines ``__len__``, so an
    # empty-but-valid dataset would otherwise be misreported as unhandled.
    if ds is None:
        raise ValueError("unhandled dataset load")

    return ds
|