Compare commits

...

4 Commits

Author SHA1 Message Date
Wing Lian
d1f36d7b78 set download to use revision 2024-10-11 11:03:43 -04:00
Wing Lian
87248027d0 use revision tied to head 2024-10-11 11:03:43 -04:00
Wing Lian
d0d22b7812 only use revision on hf hub backed datasets 2024-10-11 11:03:43 -04:00
Thomas Cleberg
68db5b1b67 Add support for revision dataset parameter 2024-10-11 11:03:43 -04:00
3 changed files with 74 additions and 1 deletion

View File

@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

View File

@@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name, name=config_dataset.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets(
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs, **load_ds_kwargs,
) )
elif ds_from_cloud and remote_file_system: elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=config_dataset.data_files, filename=config_dataset.data_files,
revision=config_dataset.revision,
) )
elif isinstance(config_dataset.data_files, list): elif isinstance(config_dataset.data_files, list):
fp = [] fp = []
@@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
revision=config_dataset.revision,
) )
) )
else: else:
@@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset, config_dataset=config_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
cfg=cfg, cfg=cfg,
dataset=ds,
d_base_type=d_base_type, d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style, d_prompt_style=d_prompt_style,
processor=processor, processor=processor,
) )

View File

@@ -267,6 +267,74 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_revision(self):
    """Verify that processing data from the hub works with a specific revision"""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pin the dataset to a fixed commit so the test is reproducible
        # even if the hub repo gains new commits later.
        dataset_spec = {
            "path": "mhenrichsen/alpaca_2k_test",
            "type": "alpaca",
            "revision": "d05c1cb",
        }
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 1024,
                "datasets": [dataset_spec],
            }
        )

        dataset, _ = load_tokenized_prepared_datasets(
            self.tokenizer, cfg, Path(tmp_dir) / "prepared"
        )

        # The pinned revision contains exactly 2000 rows.
        assert len(dataset) == 2000
        # Tokenization must have produced the standard training columns.
        for feature in ("input_ids", "attention_mask", "labels"):
            assert feature in dataset.features
def test_load_local_hub_with_revision(self):
    """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE: the snapshot is deliberately placed under a CWD-relative
        # path because `data_files` below is resolved relative to CWD.
        tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
        tmp_ds_path.mkdir(parents=True, exist_ok=True)
        try:
            # Materialize a pinned revision of the hub dataset locally so the
            # loader exercises the local-file code path, not the hub path.
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
                revision="d05c1cb",
            )

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
                            ],
                            # The revision must be ignored for local datasets;
                            # passing it verifies the local path tolerates it.
                            "revision": "d05c1cb",
                        },
                    ],
                }
            )
            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features
        finally:
            # Always remove the CWD-relative snapshot: without the
            # try/finally, a failed assertion or download/load error would
            # leave the downloaded dataset polluting the working directory.
            shutil.rmtree(tmp_ds_path, ignore_errors=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()