* diffusion training plugin * cleanup * nits * fixes + improvements * add back in reinit_weights (clobbered?); masking / pretrain fixes * nits * cleanup; tests draft * sample generation, tests fixes * fixes * nits * add inference support; add auto-mask token support * nits * nits * progress * simplify logging * lint * prefix args with diffusion_ * coderabbito * tests fix * nit * nits * cleanup + nits * nits * fix SFT sample gen * fixes * fix * comments * comments * lint * reward model lora fix * cleanup; fix pretraining_dataset case * gradio inference * update cfgs * update cfgs * train, generation parity, cleanup * fix * simplify * test * test fix
239 lines
9.0 KiB
Python
239 lines
9.0 KiB
Python
"""Test streaming configuration and data loading functionality."""
|
|
|
|
import unittest
|
|
from unittest.mock import Mock, patch
|
|
|
|
from datasets import IterableDataset
|
|
|
|
from axolotl.utils.config import validate_config
|
|
from axolotl.utils.data.sft import (
|
|
_prepare_streaming_dataset,
|
|
prepare_datasets,
|
|
)
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
class TestStreamingConfig(unittest.TestCase):
    """Test streaming configuration and deprecation handling."""

    @staticmethod
    def _base_config(**overrides):
        """Return a fresh minimal training-config dict with *overrides* applied.

        Each call builds new objects (including the ``datasets`` list), so one
        test cannot leak mutations into another. Callers wrap the result in
        ``DictDefault`` before passing it to ``validate_config``.
        """
        config = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "datasets": [{"path": "test/dataset", "type": "alpaca"}],
            "sequence_len": 256,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "learning_rate": 0.0001,
        }
        config.update(overrides)
        return config

    def test_streaming_multipack_buffer_size_deprecation(self):
        """Test that pretrain_multipack_buffer_size is properly deprecated."""
        # Test with old config name
        cfg_old = DictDefault(self._base_config(pretrain_multipack_buffer_size=5000))

        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
            validated_cfg = validate_config(cfg_old)

        # Other validators may warn on the same logger before the deprecation
        # notice, so search every captured line rather than only the first.
        expected = "pretrain_multipack_buffer_size` is deprecated"
        self.assertTrue(
            any(expected in line for line in cm.output),
            f"{expected!r} not found in captured warnings: {cm.output}",
        )

        # The value is expected to migrate to the new key and the old key to be
        # absent/None on the validated config.
        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
        self.assertIsNone(
            getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
        )

    def test_streaming_multipack_buffer_size_new(self):
        """Test that new streaming_multipack_buffer_size works correctly."""
        cfg_new = DictDefault(
            self._base_config(streaming_multipack_buffer_size=7000)
        )

        validated_cfg = validate_config(cfg_new)
        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)

    def test_both_buffer_sizes_raises_error(self):
        """Test that having both old and new buffer size configs raises an error."""
        cfg_both = DictDefault(
            self._base_config(
                pretrain_multipack_buffer_size=5000,
                streaming_multipack_buffer_size=7000,
            )
        )

        with self.assertRaises(ValueError) as cm:
            validate_config(cfg_both)
        self.assertIn("both are set", str(cm.exception))
|
|
|
|
|
|
class TestStreamingDatasetPreparation(unittest.TestCase):
    """Check which preparation path ``prepare_datasets`` dispatches to."""

    def setUp(self):
        tokenizer = Mock()
        tokenizer.pad_token_id = 0
        tokenizer.eos_token_id = 1
        self.tokenizer = tokenizer

    def _run_and_expect(self, preparer_mock, cfg):
        """Call ``prepare_datasets`` on *cfg* and assert it delegated to *preparer_mock*.

        The patched preparer is first given a canned 4-tuple result
        (presumably the shape ``prepare_datasets`` consumes — confirm against
        ``axolotl.utils.data.sft``), then must have been called exactly once
        with ``(cfg, tokenizer, None)``.
        """
        preparer_mock.return_value = (Mock(), None, 100, [])
        prepare_datasets(cfg, self.tokenizer)
        preparer_mock.assert_called_once_with(cfg, self.tokenizer, None)

    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
    def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
        """A config with ``streaming: true`` goes through the streaming preparer."""
        streaming_cfg = DictDefault(
            {
                "streaming": True,
                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
            }
        )
        self._run_and_expect(mock_prepare_streaming, streaming_cfg)

    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
    def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
        """A config with ``pretraining_dataset`` goes through the streaming preparer."""
        pretraining_cfg = DictDefault({"pretraining_dataset": "test/dataset"})
        self._run_and_expect(mock_prepare_streaming, pretraining_cfg)

    @patch("axolotl.utils.data.sft._prepare_standard_dataset")
    def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
        """Without streaming options, the standard preparer is used."""
        standard_cfg = DictDefault(
            {"datasets": [{"path": "test/dataset", "type": "alpaca"}]}
        )
        self._run_and_expect(mock_prepare_standard, standard_cfg)
|
|
|
|
|
|
class TestStreamingWithSamplePacking(unittest.TestCase):
    """Test streaming dataset preparation with sample packing."""

    def setUp(self):
        self.tokenizer = Mock()
        self.tokenizer.pad_token_id = 0
        self.tokenizer.eos_token_id = 1

    @staticmethod
    def _make_mock_streaming_dataset():
        """Return a Mock standing in for an empty streaming dataset.

        ``features`` is None, iterating it yields nothing, and ``.map()``
        returns the same mock so chained transforms keep operating on it.
        """
        mock_dataset = Mock()
        mock_dataset.features = None  # For streaming datasets
        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
        mock_dataset.map = Mock(return_value=mock_dataset)
        return mock_dataset

    def _collator_kwargs_for(self, pretraining_dataset):
        """Wrap a mock dataset with packing enabled and return the collator kwargs.

        ``pretrain_multipack_attn`` is always False here, so the returned
        ``multipack_attn`` kwarg shows whether ``wrap_streaming_dataset``
        honoured the config value or overrode it.

        Args:
            pretraining_dataset: ``None`` to exercise the SFT path, or a
                dataset path to exercise the pretraining path.

        Returns:
            The keyword arguments the patched collator was constructed with.
        """
        from axolotl.utils.data.streaming import wrap_streaming_dataset

        cfg = DictDefault(
            {
                "sample_packing": True,
                "pretrain_multipack_attn": False,
                "pretraining_dataset": pretraining_dataset,
                "sequence_len": 256,
                "micro_batch_size": 1,
                "streaming_multipack_buffer_size": 1000,
                "seed": 42,
            }
        )
        collator_target = (
            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
        )
        encode_target = "axolotl.utils.data.streaming.encode_packed_streaming"

        with patch(collator_target) as mock_collator, patch(encode_target):
            wrap_streaming_dataset(
                self._make_mock_streaming_dataset(), self.tokenizer, cfg, Mock()
            )

        mock_collator.assert_called_once()
        return mock_collator.call_args.kwargs

    @patch("axolotl.utils.data.sft._load_streaming_dataset")
    def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
        """Test that streaming SFT with sample_packing sets default split."""
        cfg = DictDefault(
            {
                "streaming": True,
                "sample_packing": True,
                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
                "sequence_len": 256,
                "micro_batch_size": 1,
            }
        )

        mock_load_streaming.return_value = Mock(spec=IterableDataset)

        with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
            _prepare_streaming_dataset(cfg, self.tokenizer, None)

        # Fail with a clear assertion (instead of a TypeError from indexing a
        # ``None`` call_args) if the streaming loader was never reached.
        mock_load_streaming.assert_called()

        # Check that the dataset config has split set to 'train'
        dataset_config = mock_load_streaming.call_args.args[0]
        self.assertEqual(dataset_config.split, "train")

    def test_multipack_attn_forced_true_for_sft(self):
        """Test that multipack_attn is forced to True for SFT with sample packing."""
        # No pretraining_dataset means SFT: the False config value is expected
        # to be overridden.
        call_kwargs = self._collator_kwargs_for(pretraining_dataset=None)
        self.assertTrue(call_kwargs["multipack_attn"])

    def test_multipack_attn_respects_config_for_pretraining(self):
        """Test that multipack_attn respects config for pretraining datasets."""
        # With a pretraining_dataset, the False config value is expected to be kept.
        call_kwargs = self._collator_kwargs_for(pretraining_dataset="test/dataset")
        self.assertFalse(call_kwargs["multipack_attn"])
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main(module="__main__")
|