Take split param from config in all load_dataset instances (#2281)
This commit is contained in:
@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
except (FileNotFoundError, ConnectionError):
|
except (FileNotFoundError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# gather extra args from the config
|
||||||
|
load_ds_kwargs = {}
|
||||||
|
if config_dataset.split:
|
||||||
|
load_ds_kwargs["split"] = config_dataset.split
|
||||||
|
else:
|
||||||
|
load_ds_kwargs["split"] = None
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(config_dataset.path)
|
local_path = Path(config_dataset.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
load_ds_kwargs = {}
|
|
||||||
if config_dataset.split:
|
|
||||||
load_ds_kwargs["split"] = config_dataset.split
|
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif config_dataset.path.startswith("https://"):
|
elif config_dataset.path.startswith("https://"):
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=fp,
|
data_files=fp,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise ValueError("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
|
|||||||
Reference in New Issue
Block a user