Take split param from config in all load_dataset instances (#2281)

This commit is contained in:
mashdragon
2025-01-24 15:06:50 +00:00
committed by GitHub
parent 74f9782fc3
commit b2774af66c

View File

@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError):
pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
try:
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")