Take split param from config in all load_dataset instances (#2281)

This commit is contained in:
mashdragon
2025-01-24 15:06:50 +00:00
committed by GitHub
parent 74f9782fc3
commit b2774af66c

View File

@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError): except (FileNotFoundError, ConnectionError):
pass pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(config_dataset.path) local_path = Path(config_dataset.path)
if local_path.exists(): if local_path.exists():
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
streaming=False, streaming=False,
split=None, **load_ds_kwargs,
) )
else: else:
try: try:
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
streaming=False, streaming=False,
split=None, **load_ds_kwargs,
) )
elif local_path.is_file(): elif local_path.is_file():
ds_type = get_ds_type(config_dataset) ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None, **load_ds_kwargs,
) )
else: else:
raise ValueError( raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file" "unhandled dataset load: local path exists, but is neither a directory or a file"
) )
elif ds_from_hub: elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset( ds = load_dataset(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
) )
elif config_dataset.path.startswith("https://"): elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset) ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
) )
else: else:
if isinstance(config_dataset.data_files, str): if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=fp, data_files=fp,
streaming=False, streaming=False,
split=None, **load_ds_kwargs,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")