support passing trust_remote_code to dataset loading (#2050) [skip ci]

* support passing trust_remote_code to dataset loading

* add doc for trust_remote_code in dataset config
This commit is contained in:
Wing Lian
2024-11-15 19:09:48 -05:00
committed by GitHub
parent 521e62daf1
commit 15f1462ccd
2 changed files with 7 additions and 1 deletions

View File

@@ -91,6 +91,7 @@ datasets:
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
trust_remote_code: # Optional[bool] Whether to allow execution of code from the dataset repository when loading. Only enable for sources you trust.
# Custom user instruction prompt
- path: repo

View File

@@ -260,6 +260,7 @@ def load_tokenized_prepared_datasets(
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
@@ -269,6 +270,7 @@ def load_tokenized_prepared_datasets(
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -366,7 +368,7 @@ def load_tokenized_prepared_datasets(
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs = {"split": config_dataset.split}
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -374,6 +376,7 @@ def load_tokenized_prepared_datasets(
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -391,6 +394,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -401,6 +405,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):