Data loader refactor (#2707)
* data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldn't resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplification * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Test dataset loading under various conditions.
|
||||
"""
|
||||
"""Test dataset loading under various conditions."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -12,8 +11,9 @@ from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||
def tokenizer(
|
||||
self, tokenizer_huggyllama
|
||||
) -> Generator[PreTrainedTokenizer, Any, Any]:
|
||||
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||
yield tokenizer_huggyllama
|
||||
|
||||
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -136,7 +144,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -145,7 +156,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
@@ -171,7 +182,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -206,7 +220,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -235,7 +252,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -264,7 +284,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -286,7 +309,8 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" not in train_dataset.features
|
||||
@@ -318,7 +342,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -342,13 +369,16 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" not in train_dataset.features
|
||||
@@ -393,16 +423,18 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
"axolotl.utils.data.shared.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
|
||||
str(prepared_path),
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -437,7 +469,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
|
||||
Reference in New Issue
Block a user