Streaming SFT support (#3101)

* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbit

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
Dan Saunders
2025-09-02 12:08:44 -04:00
committed by GitHub
parent 0094a2d744
commit 231a67e70b
24 changed files with 849 additions and 283 deletions
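
What this commit adds, in one line of config: streaming SFT reuses the regular `datasets:` entries and is switched on with `streaming: true`; the packing buffer knob is renamed from `pretrain_multipack_buffer_size` to `streaming_multipack_buffer_size`. A minimal sketch, mirroring the e2e test added below (values are illustrative; field names are taken from this diff):

from axolotl.utils.dict import DictDefault

# Minimal streaming-SFT config sketch; only `streaming`, `datasets`, and
# `streaming_multipack_buffer_size` are the knobs this commit introduces/renames.
cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
        "streaming": True,  # tokenize on the fly instead of pre-processing
        "sample_packing": True,
        "streaming_multipack_buffer_size": 10000,  # was pretrain_multipack_buffer_size
        "sequence_len": 1024,
        "micro_batch_size": 1,
        "val_set_size": 0.0,  # see "add val_set_size validation" above
    }
)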

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
         "liger_rms_norm": True,
         "liger_glu_activation": True,
         "torch_compile": True,
-        "chat_template": "llama3",
+        "chat_template": "qwen3",
         "kd_trainer": True,
         "kd_ce_alpha": 0.1,
         "kd_alpha": 0.9,

View File

@@ -0,0 +1,73 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_streaming_dataset(self, temp_dir, sample_packing):
"""Test streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
3.0,
"Train Loss (%s) is too high",
)

View File

@@ -6,7 +6,7 @@ import unittest
 
 from transformers import LlamaTokenizer
 
-from axolotl.utils.data import encode_pretraining, md5
+from axolotl.utils.data import encode_streaming, md5
 
 from tests.hf_offline_utils import enable_hf_offline
 
@@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
 
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+        result = encode_streaming(examples, self.tokenizer, self.max_tokens)
         self.assertEqual(len(result["input_ids"]), 3)
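
Note the signature change that rides along with the rename: `encode_pretraining(tokenizer, max_tokens, examples)` becomes `encode_streaming(examples, tokenizer, max_tokens)`, putting the batch of examples first. A hedged usage sketch of why that order is convenient (the `map` wiring below is an assumption, not shown in this diff; dataset and model names are illustrative):

from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

from axolotl.utils.data import encode_streaming  # import path as in the test above

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
stream = load_dataset("allenai/c4", "en", split="train", streaming=True)

# With the examples dict first, encode_streaming drops straight into the
# batched-map convention; the remaining arguments are bound via partial.
encode = partial(encode_streaming, tokenizer=tokenizer, max_tokens=2048)
tokenized = stream.map(encode, batched=True)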

View File

@@ -1,16 +1,11 @@
 """Module for testing dataset sequence packing"""
 
 import unittest
-from pathlib import Path
 
-from datasets import Dataset, load_dataset
 from transformers import AutoTokenizer
 
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
-from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
-from axolotl.prompters import AlpacaPrompter
 from axolotl.train import setup_model_and_trainer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
@@ -35,43 +30,6 @@ class TestPacking(unittest.TestCase):
             }
         )
 
-    def test_increments_attention(self):
-        prompter = AlpacaPrompter("chat")
-        strat = AlpacaPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-
-        dateset = load_dataset(
-            "json",
-            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
-        )["train"]
-        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
-
-        constant_len_dataset = ConstantLengthDataset(
-            self.tokenizer,
-            [dataset],
-            seq_length=2048,
-        )
-        packed_dataset = Dataset.from_list(list(constant_len_dataset))
-        example = packed_dataset[0]
-        next_bos_index = (
-            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
-        )  # add one since we sliced
-
-        # first example doesn't have mask reset
-        assert example["input_ids"][0] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][0] == 1
-        assert example["position_ids"][0] == 0
-        assert example["position_ids"][1] == 1
-
-        # but subsequent one does
-        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == 2
-        assert example["position_ids"][next_bos_index] == 0
-        assert example["position_ids"][next_bos_index + 1] == 1
-
     @with_temp_dir
     def test_lora_packing(self, temp_dir):
         cfg = DictDefault(

View File

@@ -9,7 +9,7 @@ import torch
 
 from datasets import IterableDataset
 from torch.utils.data import DataLoader
 
-from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
 from axolotl.utils.dict import DictDefault
@@ -77,14 +77,11 @@ class TestPretrainingPacking:
         )
         original_bsz = cfg.micro_batch_size
 
-        train_dataset = wrap_pretraining_dataset(
+        train_dataset = wrap_streaming_dataset(
             dataset,
             tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
-            max_tokens=cfg.sequence_len,
-            batch_size=cfg.micro_batch_size,
-            seed=cfg.seed or 42,
         )
 
         trainer_loader = DataLoader(
tests/test_streaming.py (new file, 238 lines)
View File

@@ -0,0 +1,238 @@
"""Test streaming configuration and data loading functionality."""
import unittest
from unittest.mock import Mock, patch
from datasets import IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
from axolotl.utils.config import validate_config
class TestStreamingConfig(unittest.TestCase):
"""Test streaming configuration and deprecation handling."""
def test_streaming_multipack_buffer_size_deprecation(self):
"""Test that pretrain_multipack_buffer_size is properly deprecated."""
# Test with old config name
cfg_old = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
validated_cfg = validate_config(cfg_old)
self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
self.assertIsNone(
getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
)
def test_streaming_multipack_buffer_size_new(self):
"""Test that new streaming_multipack_buffer_size works correctly."""
cfg_new = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
validated_cfg = validate_config(cfg_new)
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)
def test_both_buffer_sizes_raises_error(self):
"""Test that having both old and new buffer size configs raises an error."""
cfg_both = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertRaises(ValueError) as cm:
validate_config(cfg_both)
self.assertIn("both are set", str(cm.exception))
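
These three tests pin down the migration contract: warn and migrate on the old name, accept the new name, and refuse both at once. A minimal sketch of validator logic that would satisfy them (a hypothetical helper, not the actual code in `axolotl.utils.schemas.validation`):

import logging

LOG = logging.getLogger("axolotl.utils.schemas.validation")

def migrate_buffer_size(cfg):
    # Hypothetical helper mirroring the behavior asserted above.
    old = cfg.get("pretrain_multipack_buffer_size")
    new = cfg.get("streaming_multipack_buffer_size")
    if old is not None and new is not None:
        raise ValueError(
            "`pretrain_multipack_buffer_size` and "
            "`streaming_multipack_buffer_size` both are set; "
            "use only the latter"
        )
    if old is not None:
        LOG.warning(
            "`pretrain_multipack_buffer_size` is deprecated; "
            "use `streaming_multipack_buffer_size` instead"
        )
        cfg["streaming_multipack_buffer_size"] = old
        cfg["pretrain_multipack_buffer_size"] = None
    return cfg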
+
+
+class TestStreamingDatasetPreparation(unittest.TestCase):
+    """Test dataset preparation with streaming configuration."""
+
+    def setUp(self):
+        self.tokenizer = Mock()
+        self.tokenizer.pad_token_id = 0
+        self.tokenizer.eos_token_id = 1
+
+    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
+    def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
+        """Test that streaming=True triggers streaming dataset preparation."""
+        cfg = DictDefault(
+            {
+                "streaming": True,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+            }
+        )
+        mock_prepare_streaming.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
+
+    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
+    def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
+        """Test that pretraining_dataset triggers streaming dataset preparation."""
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": "test/dataset",
+            }
+        )
+        mock_prepare_streaming.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
+
+    @patch("axolotl.utils.data.sft._prepare_standard_dataset")
+    def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
+        """Test that without streaming, standard dataset preparation is used."""
+        cfg = DictDefault(
+            {
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+            }
+        )
+        mock_prepare_standard.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+        mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
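
Together, the three mocked tests above fix the dispatch rule in `prepare_datasets`: either `streaming: true` or a `pretraining_dataset` routes to the streaming path, everything else to the standard one. A sketch of that rule (inferred from the mocks; the `processor` name for the third argument is a guess, the tests pass None):

def prepare_datasets(cfg, tokenizer, processor=None):
    # Dispatch inferred from the assert_called_once_with checks above.
    if cfg.streaming or cfg.pretraining_dataset:
        return _prepare_streaming_dataset(cfg, tokenizer, processor)
    return _prepare_standard_dataset(cfg, tokenizer, processor)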
+
+
+class TestStreamingWithSamplePacking(unittest.TestCase):
+    """Test streaming dataset preparation with sample packing."""
+
+    def setUp(self):
+        self.tokenizer = Mock()
+        self.tokenizer.pad_token_id = 0
+        self.tokenizer.eos_token_id = 1
+
+    @patch("axolotl.utils.data.sft._load_streaming_dataset")
+    def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
+        """Test that streaming SFT with sample_packing sets default split."""
+        cfg = DictDefault(
+            {
+                "streaming": True,
+                "sample_packing": True,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+            }
+        )
+        mock_load_streaming.return_value = Mock(spec=IterableDataset)
+
+        with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
+            _prepare_streaming_dataset(cfg, self.tokenizer, None)
+
+        # Check that the dataset config has split set to 'train'
+        call_args = mock_load_streaming.call_args
+        dataset_config = call_args[0][0]
+        self.assertEqual(dataset_config.split, "train")
+
+    def test_multipack_attn_forced_true_for_sft(self):
+        """Test that multipack_attn is forced to True for SFT with sample packing."""
+        from axolotl.utils.data.streaming import wrap_streaming_dataset
+
+        cfg = DictDefault(
+            {
+                "sample_packing": True,
+                "pretrain_multipack_attn": False,  # Should be overridden for SFT
+                "pretraining_dataset": None,  # This makes it SFT
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "streaming_multipack_buffer_size": 1000,
+                "seed": 42,
+            }
+        )
+
+        mock_dataset = Mock()
+        mock_dataset.features = None  # For streaming datasets
+        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
+        mock_dataset.map = Mock(return_value=mock_dataset)
+        mock_ds_wrapper = Mock()
+
+        with patch(
+            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
+        ) as mock_collator:
+            with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
+                wrap_streaming_dataset(
+                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
+                )
+
+        # Check that multipack_attn=True was used in the collator
+        mock_collator.assert_called_once()
+        call_kwargs = mock_collator.call_args[1]
+        self.assertTrue(call_kwargs["multipack_attn"])
+
+    def test_multipack_attn_respects_config_for_pretraining(self):
+        """Test that multipack_attn respects config for pretraining datasets."""
+        from axolotl.utils.data.streaming import wrap_streaming_dataset
+
+        cfg = DictDefault(
+            {
+                "sample_packing": True,
+                "pretrain_multipack_attn": False,  # Should be respected for pretraining
+                "pretraining_dataset": "test/dataset",  # This makes it pretraining
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "streaming_multipack_buffer_size": 1000,
+                "seed": 42,
+            }
+        )
+
+        mock_dataset = Mock()
+        mock_dataset.features = None  # For streaming datasets
+        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
+        mock_dataset.map = Mock(return_value=mock_dataset)
+        mock_ds_wrapper = Mock()
+
+        with patch(
+            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
+        ) as mock_collator:
+            with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
+                wrap_streaming_dataset(
+                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
+                )
+
+        # Check that multipack_attn=False was used (respecting config)
+        mock_collator.assert_called_once()
+        call_kwargs = mock_collator.call_args[1]
+        self.assertFalse(call_kwargs["multipack_attn"])
+
+
+if __name__ == "__main__":
+    unittest.main()
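
The last two tests encode a one-line policy: for streaming SFT (no `pretraining_dataset`) the collator always gets `multipack_attn=True`, while pretraining keeps honoring `pretrain_multipack_attn`. A sketch of the wiring this implies inside `wrap_streaming_dataset` (hypothetical; only the `multipack_attn` kwarg is actually asserted by the tests):

# Hypothetical fragment, inferred from the two multipack_attn tests above.
is_pretraining = bool(cfg.pretraining_dataset)
multipack_attn = cfg.pretrain_multipack_attn if is_pretraining else True

collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
    tokenizer,
    multipack_attn=multipack_attn,  # checked via mock_collator.call_args[1]
)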