From 1c33eb88a7eff995aa3440c15a88074fe95eb877 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 May 2023 13:08:49 -0400 Subject: [PATCH] new hf_use_auth_token setting so login to hf isn't required --- README.md | 3 +++ src/axolotl/utils/data.py | 12 +++++++----- src/axolotl/utils/validation.py | 3 +++ tests/test_validation.py | 26 ++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c8874bd7f..adc3c5812 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,9 @@ datasets: dataset_prepared_path: data/last_run_prepared # push prepared dataset to hub push_dataset_to_hub: # repo path +# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets +# required to be true when used in combination with `push_dataset_to_hub` +hf_use_auth_token: # boolean # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc val_set_size: 0.04 # Num shards for whole dataset diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index b2045c229..a0cff21c4 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -61,10 +61,11 @@ def load_tokenized_prepared_datasets( else Path(default_dataset_prepared_path) / ds_hash ) dataset = None + use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True + f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] except: @@ -84,7 +85,7 @@ def load_tokenized_prepared_datasets( ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: - load_dataset(d.path, streaming=True, use_auth_token=True) + load_dataset(d.path, streaming=True, use_auth_token=use_auth_token) ds_from_hub = True except FileNotFoundError: pass @@ -100,10 +101,10 @@ def load_tokenized_prepared_datasets( d.path, streaming=False, data_files=d.data_files, - use_auth_token=True, + use_auth_token=use_auth_token, ) else: - ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True) + ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files @@ -274,13 +275,14 @@ def load_prepare_datasets( ) dataset = None + use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: logging.info( f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True + f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] except: diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 97cde0677..bc2940d5e 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -37,6 +37,9 @@ def validate_config(cfg): "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." ) + if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True: + raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index e754b0ea7..71bed89aa 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -93,3 +93,29 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) + + def test_hf_use_auth_token(self): + base_cfg = DictDefault( + { + "push_dataset_to_hub": None, + "hf_use_auth_token": None, + } + ) + + cfg = base_cfg | DictDefault( + { + "push_dataset_to_hub": "namespace/repo", + } + ) + + with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "push_dataset_to_hub": "namespace/repo", + "hf_use_auth_token": True, + } + ) + validate_config(cfg) +