Files
axolotl/tests/test_datasets.py
kallewoof 92ee4256f7 feature: raise on long sequence drop (#3321)
* feature: raise on long sequence drop

It is sometimes undesirable for sequences to be silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted to the training context; in that case, an over-long sequence suggests that an error occurred somewhere earlier in the process. This feature adds a third value for excess_length_strategy called 'raise', which raises a ValueError when a sequence is encountered that is too long and would otherwise have been dropped or truncated.

* tests: add excess_length_strategy tests

* doc: updated return value description for drop_long_seq_in_dataset

* add @enable_hf_offline

* fixed cfg being modified after validate_config was called

* hf offline fix

* fix tqdm desc when raise is used

* test: added test for non-batched case

* accidental code change revert

* test: use pytest.raises

* test: simplified drop_seq_len tests

* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-22 13:59:49 -05:00

490 lines
19 KiB
Python

"""Test dataset loading under various conditions."""
import shutil
import tempfile
from pathlib import Path
from typing import Any, Generator
from unittest.mock import patch
import pytest
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
_load_tokenized_prepared_datasets,
)
from axolotl.utils.dict import DictDefault
from tests.constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from tests.hf_offline_utils import enable_hf_offline
class TestDatasetPreparation:
    """Test loading and tokenizing datasets from the hub and from local files.

    Most tests build a minimal ``DictDefault`` config, run
    ``_load_tokenized_prepared_datasets`` with the prepared-dataset path patched
    to a location inside a temporary directory, and then check the row count and
    the tokenized columns of the result.
    """

    @staticmethod
    def _load_prepared(tokenizer, cfg, prepared_path: Path):
        """Run ``_load_tokenized_prepared_datasets`` with the prepared path patched.

        ``axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH`` is replaced by
        ``str(prepared_path)`` only for the duration of the call. Returns the
        first element of the loader's returned tuple (the tokenized dataset).
        """
        with patch(
            "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
        ):
            dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
        return dataset

    @staticmethod
    def _assert_tokenized(dataset, expected_len: int) -> None:
        """Assert ``dataset`` has ``expected_len`` rows and the tokenized columns."""
        assert len(dataset) == expected_len
        for column in ("input_ids", "attention_mask", "labels"):
            assert column in dataset.features, f"missing column {column!r}"

    @pytest.fixture
    def tokenizer(
        self, tokenizer_huggyllama
    ) -> Generator[PreTrainedTokenizer, Any, Any]:
        """Yield the huggyllama tokenizer fixture with ``SPECIAL_TOKENS`` added."""
        # NOTE(review): this mutates ``tokenizer_huggyllama`` in place; if that
        # fixture has a scope wider than function, the added tokens persist.
        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
        yield tokenizer_huggyllama

    @pytest.fixture
    def dataset_fixture(self):
        """Yield a one-row alpaca-style dataset (instruction / input / output)."""
        yield Dataset.from_list(
            [
                {
                    "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                    "input": "He finnished his meal and left the resturant",
                    "output": "He finished his meal and left the restaurant.",
                }
            ]
        )

    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
    @enable_hf_offline
    def test_load_hub(self, tokenizer):
        """Core use case. Verify that processing data from the hub works"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                        },
                    ],
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 2000)

    # The skip marker is outermost so skipping does not depend on how
    # ``enable_hf_offline`` wraps the test function.
    @pytest.mark.skip("datasets bug with local datasets when offline")
    @enable_hf_offline
    def test_load_local_hub(self, tokenizer):
        """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )
            # offline mode doesn't actually copy it to local_dir, so we
            # have to copy all the contents in the dir manually from the returned snapshot_path
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
            prepared_path = Path(tmp_dir) / "prepared"
            # Right now a local copy that doesn't fully conform to a dataset
            # must list data_files and ds_type otherwise the loader won't know
            # how to load it.
            # NOTE(review): tokenizer_config differs from the huggyllama
            # tokenizer fixture passed to the loader; confirm this is intended.
            cfg = DictDefault(
                {
                    "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                f"{tmp_ds_path}/alpaca_2000.parquet",
                            ],
                        },
                    ],
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 2000)
            # No explicit rmtree needed: TemporaryDirectory removes tmp_ds_path.

    @enable_hf_offline
    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
        """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
            dataset_fixture.save_to_disk(str(tmp_ds_name))
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_name),
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 1)

    @enable_hf_offline
    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
        """Usual use case. Verify a directory of parquet files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.parquet"
            dataset_fixture.to_parquet(tmp_ds_path)
            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "parquet",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 1)

    @enable_hf_offline
    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a directory of json files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.json"
            dataset_fixture.to_json(tmp_ds_path)
            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "json",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 1)

    @enable_hf_offline
    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a single parquet file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
            dataset_fixture.to_parquet(tmp_ds_path)
            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 1)

    @enable_hf_offline
    def test_load_from_single_json(self, tokenizer, dataset_fixture):
        """Standard use case. Verify a single json file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
            dataset_fixture.to_json(tmp_ds_path)
            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 1)

    @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
    @enable_hf_offline
    def test_load_hub_with_dpo(self):
        """Verify that processing dpo data from the hub works"""
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 1024,
                "rl": "dpo",
                "chat_template": "llama3",
                "datasets": [ALPACA_MESSAGES_CONFIG_OG],
            }
        )
        tokenizer = load_tokenizer(cfg)
        train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
        assert len(train_dataset) == 1800
        assert "conversation" not in train_dataset.features
        assert "chosen" in train_dataset.features
        assert "rejected" in train_dataset.features
        assert "prompt" in train_dataset.features

    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
    @enable_hf_offline
    def test_load_hub_with_revision(self, tokenizer):
        """Verify that processing data from the hub works with a specific revision"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # A fresh TemporaryDirectory guarantees prepared_path starts empty.
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                            "revision": "d05c1cb",
                        },
                    ],
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 2000)

    @enable_hf_offline
    def test_load_hub_with_revision_with_dpo(
        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
    ):
        """Verify that processing dpo data from the hub works with a specific revision"""
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 1024,
                "rl": "dpo",
                "chat_template": "llama3",
                "datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
                "dataset_num_proc": 4,
            }
        )
        with patch(
            "axolotl.utils.data.rl.load_dataset_with_config"
        ) as mock_load_dataset:
            # return_value (not side_effect): every call to the patched loader
            # returns this same fixture dataset.
            mock_load_dataset.return_value = (
                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
            )
            tokenizer = load_tokenizer(cfg)
            train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
            assert len(train_dataset) == 1800
            assert "conversation" not in train_dataset.features
            assert "chosen" in train_dataset.features
            assert "rejected" in train_dataset.features
            assert "prompt" in train_dataset.features

    # The skip marker is outermost so skipping does not depend on how
    # ``enable_hf_offline`` wraps the test function.
    @pytest.mark.skip("datasets bug with local datasets when offline")
    @enable_hf_offline
    def test_load_local_hub_with_revision(
        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
    ):
        """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
                revision="d05c1cb",
            )
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                f"{tmp_ds_path}/alpaca_2000.parquet",
                            ],
                            "revision": "d05c1cb",
                        },
                    ],
                }
            )
            with patch(
                "axolotl.utils.data.shared.load_dataset_with_config"
            ) as mock_load_dataset:
                # return_value (not side_effect): every call to the patched
                # loader returns this same fixture dataset.
                # NOTE(review): the fixture name suggests DPO-style data while the
                # config uses type "alpaca"; confirm expectations before un-skipping.
                mock_load_dataset.return_value = (
                    dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
                )
                dataset = self._load_prepared(tokenizer, cfg, prepared_path)
                self._assert_tokenized(dataset, 2000)
            # No explicit rmtree needed: TemporaryDirectory removes tmp_ds_path.

    @enable_hf_offline
    def test_loading_local_dataset_folder(self, tokenizer):
        """Verify that a dataset downloaded to a local folder can be loaded"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_path = snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
            )
            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "type": "alpaca",
                        },
                    ],
                    "dataset_num_proc": 4,
                }
            )
            dataset = self._load_prepared(tokenizer, cfg, prepared_path)
            self._assert_tokenized(dataset, 2000)
            # No explicit rmtree needed: TemporaryDirectory removes tmp_ds_path.