"""Test dataset loading under various conditions."""
|
|
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any, Generator
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from datasets import Dataset
|
|
from huggingface_hub import snapshot_download
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
from axolotl.loaders.tokenizer import load_tokenizer
|
|
from axolotl.utils.data.rl import prepare_preference_datasets
|
|
from axolotl.utils.data.sft import (
|
|
_load_tokenized_prepared_datasets,
|
|
)
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
from tests.constants import (
|
|
ALPACA_MESSAGES_CONFIG_OG,
|
|
ALPACA_MESSAGES_CONFIG_REVISION,
|
|
SPECIAL_TOKENS,
|
|
)
|
|
from tests.hf_offline_utils import enable_hf_offline
|
|
|
|
|
|
class TestDatasetPreparation:
    """Test a configured dataloader."""

    @pytest.fixture
    def tokenizer(
        self, tokenizer_huggyllama
    ) -> Generator[PreTrainedTokenizer, Any, Any]:
        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
        yield tokenizer_huggyllama

    @pytest.fixture
    def dataset_fixture(self):
        yield Dataset.from_list(
            [
                {
                    "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                    "input": "He finnished his meal and left the resturant",
                    "output": "He finished his meal and left the restaurant.",
                }
            ]
        )
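    # The fixture above is a single row in the alpaca schema
    # (instruction/input/output) consumed by `type: alpaca` below, which is
    # why the local-file tests assert a tokenized dataset of length 1.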
    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
    @enable_hf_offline
    def test_load_hub(self, tokenizer):
        """Core use case. Verify that processing data from the hub works."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 2000
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
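    # For reference, the cfg above mirrors what a user would put in their
    # axolotl YAML config (a sketch, assuming the standard config schema):
    #
    #   sequence_len: 1024
    #   datasets:
    #     - path: mhenrichsen/alpaca_2k_test
    #       type: alpaca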
    @enable_hf_offline
    @pytest.mark.skip("datasets bug with local datasets when offline")
    def test_load_local_hub(self, tokenizer):
        """Niche use case. Verify that a local copy of a hub dataset can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )
            # Offline mode doesn't actually copy the files into local_dir, so
            # we copy the contents manually from the returned snapshot_path.
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)

            prepared_path = Path(tmp_dir) / "prepared"
            # Right now a local copy that doesn't fully conform to a dataset
            # must list data_files and ds_type; otherwise the loader won't know
            # how to load it.
            cfg = DictDefault(
                {
                    "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                f"{tmp_ds_path}/alpaca_2000.parquet",
                            ],
                        },
                    ],
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 2000
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)
    @enable_hf_offline
    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
        """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
            dataset_fixture.save_to_disk(str(tmp_ds_name))

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_name),
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 1
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
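    # `save_to_disk` writes Arrow shards plus dataset metadata (state.json,
    # dataset_info.json); the loader presumably detects that on-disk layout
    # and reads it back via `datasets.load_from_disk`, rather than the
    # file-type loaders exercised by the parquet/json tests below.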
    @enable_hf_offline
    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
        """Usual use case. Verify a directory of parquet files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.parquet"
            dataset_fixture.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "parquet",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 1
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
    @enable_hf_offline
    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a directory of json files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.json"
            dataset_fixture.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "json",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 1
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
    @enable_hf_offline
    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a single parquet file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
            dataset_fixture.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 1
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
    @enable_hf_offline
    def test_load_from_single_json(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a single json file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
            dataset_fixture.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 1
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
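    # Note: `Dataset.to_json` writes JSON Lines by default (one record per
    # line), so the "json" cases above are really JSONL files on disk.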
    @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
    @enable_hf_offline
    def test_load_hub_with_dpo(self):
        """Verify that processing DPO data from the hub works."""
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 1024,
                "rl": "dpo",
                "chat_template": "llama3",
                "datasets": [ALPACA_MESSAGES_CONFIG_OG],
            }
        )

        tokenizer = load_tokenizer(cfg)
        train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)

        assert len(train_dataset) == 1800
        assert "conversation" not in train_dataset.features
        assert "chosen" in train_dataset.features
        assert "rejected" in train_dataset.features
        assert "prompt" in train_dataset.features
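    # DPO preprocessing flattens each conversation into the preference-triple
    # columns asserted above (prompt / chosen / rejected), dropping the raw
    # "conversation" column in the process.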
    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
    @enable_hf_offline
    def test_load_hub_with_revision(self, tokenizer):
        """Verify that processing data from the hub works with a specific revision."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"

            # make sure prepared_path is empty
            shutil.rmtree(prepared_path, ignore_errors=True)

            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                            "revision": "d05c1cb",
                        },
                    ],
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 2000
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
    @enable_hf_offline
    def test_load_hub_with_revision_with_dpo(
        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
    ):
        """Verify that processing DPO data from the hub works with a specific revision."""
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 1024,
                "rl": "dpo",
                "chat_template": "llama3",
                "datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
                "dataset_num_proc": 4,
            }
        )

        with patch(
            "axolotl.utils.data.rl.load_dataset_with_config"
        ) as mock_load_dataset:
            # Have the mock return the pinned-revision fixture instead of
            # hitting the hub.
            mock_load_dataset.return_value = (
                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
            )

            tokenizer = load_tokenizer(cfg)
            train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)

            assert len(train_dataset) == 1800
            assert "conversation" not in train_dataset.features
            assert "chosen" in train_dataset.features
            assert "rejected" in train_dataset.features
            assert "prompt" in train_dataset.features
    @enable_hf_offline
    @pytest.mark.skip("datasets bug with local datasets when offline")
    def test_load_local_hub_with_revision(
        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
    ):
        """Verify that a local copy of a hub dataset can be loaded with a specific revision."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
                revision="d05c1cb",
            )
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                f"{tmp_ds_path}/alpaca_2000.parquet",
                            ],
                            "revision": "d05c1cb",
                        },
                    ],
                }
            )

            with patch(
                "axolotl.utils.data.shared.load_dataset_with_config"
            ) as mock_load_dataset:
                # Have the mock return the pinned-revision fixture instead of
                # hitting the hub.
                mock_load_dataset.return_value = (
                    dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
                )

                with patch(
                    "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
                    str(prepared_path),
                ):
                    dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                    assert len(dataset) == 2000
                    assert "input_ids" in dataset.features
                    assert "attention_mask" in dataset.features
                    assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)
    @enable_hf_offline
    def test_loading_local_dataset_folder(self, tokenizer):
        """Verify that a dataset downloaded to a local folder can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
            )
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 2000
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)
    @enable_hf_offline
    def test_load_dataset_with_str_json_data(self, tokenizer):
        """
        Test loading datasets where messages are stored as JSON-encoded strings
        instead of lists of dicts.

        See https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for details.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            str_json_ds = Dataset.from_list(
                [
                    {
                        "messages": json.dumps(
                            [
                                {"role": "user", "content": "Hello how are you?"},
                                {
                                    "role": "assistant",
                                    "content": "I am doing good thanks",
                                },
                            ]
                        )
                    },
                    {
                        "messages": json.dumps(
                            [
                                {"role": "user", "content": "What is 2+2?"},
                                {"role": "assistant", "content": "2+2 equals 4."},
                            ]
                        )
                    },
                ]
            )

            tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
            str_json_ds.to_parquet(tmp_ds_path)

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 512,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_str_json",
                            "type": "chat_template",
                            "field_messages": "messages",
                            "message_field_role": "role",
                            "message_field_content": "content",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )

            with patch(
                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
            ):
                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

                assert len(dataset) == 2
                assert "input_ids" in dataset.features
                assert "attention_mask" in dataset.features
                assert "labels" in dataset.features

                assert len(dataset[0]["input_ids"]) > 0