From 15f1462ccd7a5505cec0f27e8087956505596b58 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 15 Nov 2024 19:09:48 -0500 Subject: [PATCH] support passing trust_remote_code to dataset loading (#2050) [skip ci] * support passing trust_remote_code to dataset loading * add doc for trust_remote_code in dataset config --- docs/config.qmd | 1 + src/axolotl/utils/data/sft.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 4349f0f09..04e278e2d 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -91,6 +91,7 @@ datasets: name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. + trust_remote_code: # Optional[bool] Trust remote code for untrusted source # Custom user instruction prompt - path: repo diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index ec17fb9c2..e05c02966 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -260,6 +260,7 @@ def load_tokenized_prepared_datasets( for config_dataset in for_d_in_datasets(cfg_datasets): ds: Optional[Union[Dataset, DatasetDict]] = None ds_from_hub = False + ds_trust_remote_code = config_dataset.trust_remote_code try: # this is just a basic check to see if the path is a # valid HF dataset that's loadable @@ -269,6 +270,7 @@ def load_tokenized_prepared_datasets( streaming=True, token=use_auth_token, revision=config_dataset.revision, + trust_remote_code=ds_trust_remote_code, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -366,7 +368,7 @@ def load_tokenized_prepared_datasets( elif ds_from_hub: load_ds_kwargs = {} if config_dataset.split: - load_ds_kwargs = {"split": config_dataset.split} + load_ds_kwargs["split"] = config_dataset.split ds = load_dataset( config_dataset.path, name=config_dataset.name, @@ -374,6 +376,7 @@ def load_tokenized_prepared_datasets( data_files=config_dataset.data_files, token=use_auth_token, revision=config_dataset.revision, + trust_remote_code=config_dataset.trust_remote_code, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -391,6 +394,7 @@ def load_tokenized_prepared_datasets( streaming=False, split=None, storage_options=storage_options, + trust_remote_code=config_dataset.trust_remote_code, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -401,6 +405,7 @@ def load_tokenized_prepared_datasets( streaming=False, split=None, storage_options=storage_options, + trust_remote_code=config_dataset.trust_remote_code, ) else: if isinstance(config_dataset.data_files, str):