progress on streaming
This commit is contained in:
@@ -7,13 +7,13 @@ from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -46,6 +46,24 @@ class TestDatasetPreparation:
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
def streaming_dataset_fixture(self):
    """Build a small streaming dataset for the streaming tests.

    Returns:
        An ``IterableDataset`` backed by a generator that yields two
        records, each carrying ``instruction``, ``input`` and ``output``
        keys.
    """

    def _records():
        # The records live inside the generator so it stays self-contained.
        yield from (
            {
                "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                "input": "He finnished his meal and left the resturant",
                "output": "He finished his meal and left the restaurant.",
            },
            {
                "instruction": "What is the capital of France?",
                "input": "",
                "output": "The capital of France is Paris.",
            },
        )

    return IterableDataset.from_generator(_records)
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
def test_load_hub(self, tokenizer):
|
||||
@@ -486,3 +504,45 @@ class TestDatasetPreparation:
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
    """Check that ``prepare_datasets`` handles a streaming SFT dataset.

    The dataset loader is patched to hand back the ``IterableDataset``
    fixture, so nothing is fetched for the placeholder dataset path.
    """
    streaming_cfg = DictDefault(
        {
            "tokenizer_config": "huggyllama/llama-7b",
            "sequence_len": 256,
            "streaming": True,
            "max_steps": 100,  # Required for streaming datasets
            "datasets": [
                {
                    "path": "dummy/path",
                    "type": "alpaca",
                },
            ],
        }
    )

    with patch(
        "axolotl.utils.data.sft.load_dataset_with_config",
        return_value=streaming_dataset_fixture,
    ):
        train_ds, eval_ds, num_steps, prompter_list = prepare_datasets(
            streaming_cfg, tokenizer
        )

        # The train split stays an IterableDataset.
        assert isinstance(train_ds, IterableDataset)
        # No eval split for streaming.
        assert eval_ds is None
        # The step count should come from max_steps.
        assert num_steps == 100
        assert len(prompter_list) == 1

        # Iterate over the first two samples and check each one's columns.
        seen = 0
        for seen, row in enumerate(train_ds, start=1):
            for column in ("input_ids", "attention_mask", "labels"):
                assert column in row
            if seen >= 2:
                break

        assert seen == 2
|
||||
|
||||
Reference in New Issue
Block a user