Files
axolotl/tests/test_streaming.py
Dan Saunders 1b53c49e1a text diffusion training plugin (#3067)
* diffusion training plugin

* cleanup

* nits

* fixes + improvements

* add back in reinit_weights (clobbered?); masking / pretrain fixes

* nits

* cleanup; tests draft

* sample generation, tests fixes

* fixes

* nits

* add inference support; add auto-mask token support

* nits

* nits

* progress

* simplify logging

* lint

* prefix args with diffusion_

* coderabbito

* tests fix

* nit

* nits

* cleanup + nits

* nits

* fix SFT sample gen

* fixes

* fix

* comments

* comments

* lint

* reward model lora fix

* cleanup; fix pretraining_dataset case

* gradio inference

* update cfgs

* update cfgs

* train, generation parity, cleanup

* fix

* simplify

* test

* test fix
2025-09-10 20:27:00 -04:00

239 lines
9.0 KiB
Python

"""Test streaming configuration and data loading functionality."""
import unittest
from unittest.mock import Mock, patch
from datasets import IterableDataset
from axolotl.utils.config import validate_config
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
from axolotl.utils.dict import DictDefault
class TestStreamingConfig(unittest.TestCase):
    """Test streaming configuration and deprecation handling."""

    @staticmethod
    def _make_cfg(**buffer_settings):
        """Build the minimal config shared by every test in this class.

        Args:
            **buffer_settings: Buffer-size keys under test. They are inserted
                directly after ``base_model`` so the resulting key order is
                the same as if the dict had been written out by hand.

        Returns:
            A fresh ``DictDefault`` (including a fresh ``datasets`` list) so
            no state is shared between tests.
        """
        return DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                **buffer_settings,
                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
                "sequence_len": 256,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 0.0001,
            }
        )

    def test_streaming_multipack_buffer_size_deprecation(self):
        """Test that pretrain_multipack_buffer_size is properly deprecated.

        Expects a WARNING on the validation logger mentioning the deprecation,
        the value to appear under ``streaming_multipack_buffer_size``, and the
        old key to be absent (or ``None``) on the validated config.
        """
        # Test with old config name
        cfg_old = self._make_cfg(pretrain_multipack_buffer_size=5000)

        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
            validated_cfg = validate_config(cfg_old)

        # Look through every captured record instead of only the first one, so
        # an unrelated warning emitted earlier does not break this test.
        expected = "pretrain_multipack_buffer_size` is deprecated"
        self.assertTrue(
            any(expected in line for line in cm.output),
            msg=f"deprecation warning not found in captured logs: {cm.output}",
        )

        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
        self.assertIsNone(
            getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
        )

    def test_streaming_multipack_buffer_size_new(self):
        """Test that new streaming_multipack_buffer_size works correctly."""
        cfg_new = self._make_cfg(streaming_multipack_buffer_size=7000)

        validated_cfg = validate_config(cfg_new)

        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)

    def test_both_buffer_sizes_raises_error(self):
        """Test that having both old and new buffer size configs raises an error."""
        cfg_both = self._make_cfg(
            pretrain_multipack_buffer_size=5000,
            streaming_multipack_buffer_size=7000,
        )

        with self.assertRaises(ValueError) as cm:
            validate_config(cfg_both)

        self.assertIn("both are set", str(cm.exception))
class TestStreamingDatasetPreparation(unittest.TestCase):
    """Verify which preparation routine ``prepare_datasets`` hands off to."""

    def setUp(self):
        # A bare mock tokenizer with only the pad/eos ids pinned.
        self.tokenizer = Mock(pad_token_id=0, eos_token_id=1)

    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
    def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
        """``streaming: True`` must route through the streaming preparer."""
        dataset_entry = {"path": "test/dataset", "type": "alpaca"}
        cfg = DictDefault({"streaming": True, "datasets": [dataset_entry]})
        # The 4-tuple stands in for whatever the real preparer would return.
        mock_prepare_streaming.return_value = (Mock(), None, 100, [])

        prepare_datasets(cfg, self.tokenizer)

        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)

    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
    def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
        """A ``pretraining_dataset`` entry must route through the streaming preparer."""
        cfg = DictDefault({"pretraining_dataset": "test/dataset"})
        mock_prepare_streaming.return_value = (Mock(), None, 100, [])

        prepare_datasets(cfg, self.tokenizer)

        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)

    @patch("axolotl.utils.data.sft._prepare_standard_dataset")
    def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
        """With no streaming-related keys, the standard preparer must be used."""
        dataset_entry = {"path": "test/dataset", "type": "alpaca"}
        cfg = DictDefault({"datasets": [dataset_entry]})
        mock_prepare_standard.return_value = (Mock(), None, 100, [])

        prepare_datasets(cfg, self.tokenizer)

        mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
class TestStreamingWithSamplePacking(unittest.TestCase):
    """Test streaming dataset preparation with sample packing."""

    def setUp(self):
        self.tokenizer = Mock()
        self.tokenizer.pad_token_id = 0
        self.tokenizer.eos_token_id = 1

    def _collator_kwargs_for(self, cfg):
        """Run ``wrap_streaming_dataset`` on a stub dataset and capture the collator call.

        The collator class and ``encode_packed_streaming`` are both patched
        inside ``axolotl.utils.data.streaming`` for the duration of the call.

        Args:
            cfg: Config passed straight through to ``wrap_streaming_dataset``.

        Returns:
            The keyword arguments the patched collator class was constructed
            with. Fails the test if it was not constructed exactly once.
        """
        # Imported lazily, as the tests that use this helper did originally.
        from axolotl.utils.data.streaming import wrap_streaming_dataset

        mock_dataset = Mock()
        mock_dataset.features = None  # For streaming datasets
        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
        mock_dataset.map = Mock(return_value=mock_dataset)
        mock_ds_wrapper = Mock()

        collator_patch = patch(
            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
        )
        encode_patch = patch("axolotl.utils.data.streaming.encode_packed_streaming")
        with collator_patch as mock_collator, encode_patch:
            wrap_streaming_dataset(
                mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
            )

        mock_collator.assert_called_once()
        return mock_collator.call_args[1]

    @patch("axolotl.utils.data.sft._load_streaming_dataset")
    def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
        """Test that streaming SFT with sample_packing sets default split."""
        cfg = DictDefault(
            {
                "streaming": True,
                "sample_packing": True,
                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
                "sequence_len": 256,
                "micro_batch_size": 1,
            }
        )
        mock_load_streaming.return_value = Mock(spec=IterableDataset)

        with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
            _prepare_streaming_dataset(cfg, self.tokenizer, None)

        # Without this guard, a loader that was never invoked leaves
        # ``call_args`` as None and the indexing below raises a confusing
        # TypeError instead of a clear assertion failure.
        mock_load_streaming.assert_called()

        # Check that the dataset config has split set to 'train'
        dataset_config = mock_load_streaming.call_args[0][0]
        self.assertEqual(dataset_config.split, "train")

    def test_multipack_attn_forced_true_for_sft(self):
        """Test that multipack_attn is forced to True for SFT with sample packing."""
        cfg = DictDefault(
            {
                "sample_packing": True,
                "pretrain_multipack_attn": False,  # Should be overridden for SFT
                "pretraining_dataset": None,  # This makes it SFT
                "sequence_len": 256,
                "micro_batch_size": 1,
                "streaming_multipack_buffer_size": 1000,
                "seed": 42,
            }
        )

        call_kwargs = self._collator_kwargs_for(cfg)

        # Check that multipack_attn=True was used in the collator
        self.assertTrue(call_kwargs["multipack_attn"])

    def test_multipack_attn_respects_config_for_pretraining(self):
        """Test that multipack_attn respects config for pretraining datasets."""
        cfg = DictDefault(
            {
                "sample_packing": True,
                "pretrain_multipack_attn": False,  # Should be respected for pretraining
                "pretraining_dataset": "test/dataset",  # This makes it pretraining
                "sequence_len": 256,
                "micro_batch_size": 1,
                "streaming_multipack_buffer_size": 1000,
                "seed": 42,
            }
        )

        call_kwargs = self._collator_kwargs_for(cfg)

        # Check that multipack_attn=False was used (respecting config)
        self.assertFalse(call_kwargs["multipack_attn"])
if __name__ == "__main__":
    # Running this file directly hands control to unittest's own runner.
    unittest.main()