Take split param from config in all load_dataset instances (#2281)
This commit is contained in:
@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# gather extra args from the config
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
else:
|
||||
load_ds_kwargs["split"] = None
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
|
||||
Reference in New Issue
Block a user