diff --git a/docs/config.qmd b/docs/config.qmd index 99a69a097..8329f3553 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -90,6 +90,7 @@ datasets: shards: # Optional[int] number of shards to split data into name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from + revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 7d6922cbf..f46152a7c 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets( name=config_dataset.name, streaming=True, token=use_auth_token, + revision=config_dataset.revision, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -319,6 +320,7 @@ def load_tokenized_prepared_datasets( data_files=config_dataset.data_files, streaming=False, split=None, + revision=config_dataset.revision, ) else: ds = load_from_disk(config_dataset.path) @@ -331,6 +333,7 @@ def load_tokenized_prepared_datasets( data_files=config_dataset.path, streaming=False, split=None, + revision=config_dataset.revision, ) else: raise ValueError( @@ -346,6 +349,7 @@ def load_tokenized_prepared_datasets( streaming=False, data_files=config_dataset.data_files, token=use_auth_token, + revision=config_dataset.revision, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -363,6 +367,7 @@ def load_tokenized_prepared_datasets( streaming=False, split=None, storage_options=storage_options, + revision=config_dataset.revision, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -373,6 +378,7 @@ def load_tokenized_prepared_datasets( streaming=False, split=None, storage_options=storage_options, + revision=config_dataset.revision, ) else: if isinstance(config_dataset.data_files, str): @@ -380,6 +386,7 @@ def load_tokenized_prepared_datasets( repo_id=config_dataset.path, repo_type="dataset", filename=config_dataset.data_files, + revision=config_dataset.revision, ) elif isinstance(config_dataset.data_files, list): fp = [] @@ -389,6 +396,7 @@ def load_tokenized_prepared_datasets( repo_id=config_dataset.path, repo_type="dataset", filename=file, + revision=config_dataset.revision, ) ) else: @@ -433,8 +441,8 @@ def load_tokenized_prepared_datasets( config_dataset=config_dataset, tokenizer=tokenizer, cfg=cfg, - dataset=ds, d_base_type=d_base_type, + dataset=ds, d_prompt_style=d_prompt_style, processor=processor, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a274b7b89..5a631b2e6 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -267,6 +267,73 @@ class TestDatasetPreparation(unittest.TestCase): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_revision(self): + """Verify that processing data from the hub works with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "revision": "foo", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_local_hub_with_revision(self): + """Verify that a local copy of a hub dataset can be loaded with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + revision="foo", + ) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + "revision": "foo", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) if __name__ == "__main__": unittest.main()