Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912)
* Add support for `revision` dataset parameter * only use revision on hf hub backed datasets * use revision tied to head * set download to use revision * feat: add config to model validator class * feat: add revision config to RL and tests for it --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
|
||||
def test_load_hub_with_dpo(self):
|
||||
"""Verify that processing dpo data from the hub works"""
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"chat_template": "llama3",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
|
||||
def test_load_hub_with_revision(self):
|
||||
"""Verify that processing data from the hub works with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
|
||||
def test_load_hub_with_revision_with_dpo(self):
|
||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"chat_template": "llama3",
|
||||
"revision": "ea82cff",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
|
||||
def test_load_local_hub_with_revision(self):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
revision="d05c1cb",
|
||||
)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||
],
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user