Streaming SFT support (#3101)
* working * fixes * deprecate --iterable; cleanup * pretrain_multipack_buffer_size -> streaming_multipack_buffer_size * improvements * tests * remove unused * docs, examples * nit * nit * add val_set_size validation * val * nit * min * coderabbito * cleanup * nit * add depr warning, cleanup * nit * fix test, fix quarto * fix * review comments * review comments * fix
This commit is contained in:
238
tests/test_streaming.py
Normal file
238
tests/test_streaming.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Test streaming configuration and data loading functionality."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from datasets import IterableDataset
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.data.sft import (
|
||||
_prepare_streaming_dataset,
|
||||
prepare_datasets,
|
||||
)
|
||||
from axolotl.utils.config import validate_config
|
||||
|
||||
|
||||
class TestStreamingConfig(unittest.TestCase):
|
||||
"""Test streaming configuration and deprecation handling."""
|
||||
|
||||
def test_streaming_multipack_buffer_size_deprecation(self):
|
||||
"""Test that pretrain_multipack_buffer_size is properly deprecated."""
|
||||
# Test with old config name
|
||||
cfg_old = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"pretrain_multipack_buffer_size": 5000,
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
|
||||
validated_cfg = validate_config(cfg_old)
|
||||
self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
|
||||
|
||||
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
|
||||
self.assertIsNone(
|
||||
getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
|
||||
)
|
||||
|
||||
def test_streaming_multipack_buffer_size_new(self):
|
||||
"""Test that new streaming_multipack_buffer_size works correctly."""
|
||||
cfg_new = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"streaming_multipack_buffer_size": 7000,
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
validated_cfg = validate_config(cfg_new)
|
||||
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)
|
||||
|
||||
def test_both_buffer_sizes_raises_error(self):
|
||||
"""Test that having both old and new buffer size configs raises an error."""
|
||||
cfg_both = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"pretrain_multipack_buffer_size": 5000,
|
||||
"streaming_multipack_buffer_size": 7000,
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
validate_config(cfg_both)
|
||||
self.assertIn("both are set", str(cm.exception))
|
||||
|
||||
|
||||
class TestStreamingDatasetPreparation(unittest.TestCase):
|
||||
"""Test dataset preparation with streaming configuration."""
|
||||
|
||||
def setUp(self):
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.pad_token_id = 0
|
||||
self.tokenizer.eos_token_id = 1
|
||||
|
||||
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
|
||||
def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
|
||||
"""Test that streaming=True triggers streaming dataset preparation."""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"streaming": True,
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
}
|
||||
)
|
||||
|
||||
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
|
||||
|
||||
prepare_datasets(cfg, self.tokenizer)
|
||||
|
||||
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
|
||||
|
||||
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
|
||||
def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
|
||||
"""Test that pretraining_dataset triggers streaming dataset preparation."""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"pretraining_dataset": "test/dataset",
|
||||
}
|
||||
)
|
||||
|
||||
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
|
||||
|
||||
prepare_datasets(cfg, self.tokenizer)
|
||||
|
||||
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
|
||||
|
||||
@patch("axolotl.utils.data.sft._prepare_standard_dataset")
|
||||
def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
|
||||
"""Test that without streaming, standard dataset preparation is used."""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
}
|
||||
)
|
||||
|
||||
mock_prepare_standard.return_value = (Mock(), None, 100, [])
|
||||
|
||||
prepare_datasets(cfg, self.tokenizer)
|
||||
|
||||
mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
|
||||
|
||||
|
||||
class TestStreamingWithSamplePacking(unittest.TestCase):
|
||||
"""Test streaming dataset preparation with sample packing."""
|
||||
|
||||
def setUp(self):
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.pad_token_id = 0
|
||||
self.tokenizer.eos_token_id = 1
|
||||
|
||||
@patch("axolotl.utils.data.sft._load_streaming_dataset")
|
||||
def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
|
||||
"""Test that streaming SFT with sample_packing sets default split."""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"streaming": True,
|
||||
"sample_packing": True,
|
||||
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
mock_load_streaming.return_value = Mock(spec=IterableDataset)
|
||||
|
||||
with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
|
||||
_prepare_streaming_dataset(cfg, self.tokenizer, None)
|
||||
|
||||
# Check that the dataset config has split set to 'train'
|
||||
call_args = mock_load_streaming.call_args
|
||||
dataset_config = call_args[0][0]
|
||||
self.assertEqual(dataset_config.split, "train")
|
||||
|
||||
def test_multipack_attn_forced_true_for_sft(self):
|
||||
"""Test that multipack_attn is forced to True for SFT with sample packing."""
|
||||
from axolotl.utils.data.streaming import wrap_streaming_dataset
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pretrain_multipack_attn": False, # Should be overridden for SFT
|
||||
"pretraining_dataset": None, # This makes it SFT
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
"streaming_multipack_buffer_size": 1000,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.features = None # For streaming datasets
|
||||
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
|
||||
mock_dataset.map = Mock(return_value=mock_dataset)
|
||||
mock_ds_wrapper = Mock()
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
|
||||
) as mock_collator:
|
||||
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
|
||||
wrap_streaming_dataset(
|
||||
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
|
||||
)
|
||||
|
||||
# Check that multipack_attn=True was used in the collator
|
||||
mock_collator.assert_called_once()
|
||||
call_kwargs = mock_collator.call_args[1]
|
||||
self.assertTrue(call_kwargs["multipack_attn"])
|
||||
|
||||
def test_multipack_attn_respects_config_for_pretraining(self):
|
||||
"""Test that multipack_attn respects config for pretraining datasets."""
|
||||
from axolotl.utils.data.streaming import wrap_streaming_dataset
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pretrain_multipack_attn": False, # Should be respected for pretraining
|
||||
"pretraining_dataset": "test/dataset", # This makes it pretraining
|
||||
"sequence_len": 256,
|
||||
"micro_batch_size": 1,
|
||||
"streaming_multipack_buffer_size": 1000,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.features = None # For streaming datasets
|
||||
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
|
||||
mock_dataset.map = Mock(return_value=mock_dataset)
|
||||
mock_ds_wrapper = Mock()
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
|
||||
) as mock_collator:
|
||||
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
|
||||
wrap_streaming_dataset(
|
||||
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
|
||||
)
|
||||
|
||||
# Check that multipack_attn=False was used (respecting config)
|
||||
mock_collator.assert_called_once()
|
||||
call_kwargs = mock_collator.call_args[1]
|
||||
self.assertFalse(call_kwargs["multipack_attn"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user