Streaming SFT support (#3101)
* working
* fixes
* deprecate --iterable; cleanup
* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size
* improvements
* tests
* remove unused
* docs, examples
* nit
* nit
* add val_set_size validation
* val
* nit
* min
* coderabbito
* cleanup
* nit
* add depr warning, cleanup
* nit
* fix test, fix quarto
* fix
* review comments
* review comments
* fix
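The feature gates on two config keys that the tests below exercise: `streaming: true` enables streaming SFT, and `streaming_multipack_buffer_size` (replacing the deprecated `pretrain_multipack_buffer_size`) sizes the buffer used for on-the-fly sample packing. A minimal config sketch assembled from the e2e test in this commit; the model, dataset, and hyperparameter values are simply the test's, and a real run would tune them:

    from axolotl.utils.config import normalize_config, validate_config
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
            "streaming": True,  # stream and tokenize on the fly instead of pre-tokenizing
            "sample_packing": True,
            "streaming_multipack_buffer_size": 10000,  # buffer for on-the-fly packing
            "sequence_len": 1024,
            "micro_batch_size": 1,
            "val_set_size": 0.0,  # this commit adds validation of val_set_size with streaming
            "output_dir": "./outputs/streaming-sft",  # hypothetical path
        }
    )
    cfg = validate_config(cfg)
    normalize_config(cfg)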
@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
             "liger_rms_norm": True,
             "liger_glu_activation": True,
             "torch_compile": True,
-            "chat_template": "llama3",
+            "chat_template": "qwen3",
             "kd_trainer": True,
             "kd_ce_alpha": 0.1,
             "kd_alpha": 0.9,
tests/e2e/test_streaming.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+"""E2E tests for streaming dataset functionality"""
+
+# pylint: disable=duplicate-code
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists, check_tensorboard
+
+
+class TestStreamingDatasets:
+    """Test case for streaming datasets"""
+
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_streaming_dataset(self, temp_dir, sample_packing):
+        """Test streaming datasets"""
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "sample_packing": sample_packing,
+                "pretrain_multipack_attn": sample_packing,
+                "streaming_multipack_buffer_size": 10000,
+                "dataset_processes": 1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                # Streaming config
+                "streaming": True,
+                "max_steps": 3,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "val_set_size": 0.0,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "save_safetensors": True,
+                "bf16": "auto",
+                "use_tensorboard": True,
+                "save_first_step": False,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
+
+        # Verify training actually happened by checking loss decrease
+        check_tensorboard(
+            temp_dir + "/runs",
+            "train/train_loss",
+            3.0,
+            "Train Loss (%s) is too high",
+        )
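The new e2e test can be run standalone with plain pytest (`pytest tests/e2e/test_streaming.py` from the repo root); the Python-level equivalent, for completeness:

    import pytest

    # Runs both parametrized cases (sample_packing True and False); -q for terse output.
    pytest.main(["tests/e2e/test_streaming.py", "-q"])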
@@ -6,7 +6,7 @@ import unittest
 
 from transformers import LlamaTokenizer
 
-from axolotl.utils.data import encode_pretraining, md5
+from axolotl.utils.data import encode_streaming, md5
 
 from tests.hf_offline_utils import enable_hf_offline
 
@@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+        result = encode_streaming(examples, self.tokenizer, self.max_tokens)
 
         self.assertEqual(len(result["input_ids"]), 3)
 
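Note that the rename also reorders the arguments: `encode_pretraining(tokenizer, max_tokens, examples)` becomes `encode_streaming(examples, tokenizer, max_tokens)`. Batch-first matches the callback shape that `datasets.map(..., batched=True)` passes in, so the remaining arguments can be bound up front; a sketch, assuming the parameters are named `tokenizer` and `max_tokens`:

    from functools import partial

    # Bind everything but the batch; datasets.map supplies the batch positionally.
    encode_fn = partial(encode_streaming, tokenizer=tokenizer, max_tokens=max_tokens)
    dataset = dataset.map(encode_fn, batched=True)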
@@ -1,16 +1,11 @@
|
||||
"""Module for testing dataset sequence packing"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
from axolotl.train import setup_model_and_trainer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -35,43 +30,6 @@ class TestPacking(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
def test_increments_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
dateset = load_dataset(
|
||||
"json",
|
||||
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
|
||||
)["train"]
|
||||
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
|
||||
|
||||
constant_len_dataset = ConstantLengthDataset(
|
||||
self.tokenizer,
|
||||
[dataset],
|
||||
seq_length=2048,
|
||||
)
|
||||
packed_dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
example = packed_dataset[0]
|
||||
next_bos_index = (
|
||||
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
|
||||
) # add one since we sliced
|
||||
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
assert example["position_ids"][0] == 0
|
||||
assert example["position_ids"][1] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 2
|
||||
assert example["position_ids"][next_bos_index] == 0
|
||||
assert example["position_ids"][next_bos_index + 1] == 1
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
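For context on the deleted test: axolotl's sample packing concatenates examples and marks each packed segment with an incremented `attention_mask` value while restarting `position_ids` at 0 per segment, which is exactly what the removed assertions checked. A toy illustration of that layout (made-up token ids, not tokenizer output):

    # Two 4-token examples packed into one row; distinct segment ids (1, 2) let the
    # collator keep segments from attending to each other, and positions restart per segment.
    packed_example = {
        "input_ids":      [1, 15, 16, 2, 1, 23, 24, 2],  # 1 = bos, 2 = eos here
        "attention_mask": [1, 1, 1, 1, 2, 2, 2, 2],
        "position_ids":   [0, 1, 2, 3, 0, 1, 2, 3],
    }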
@@ -9,7 +9,7 @@ import torch
 from datasets import IterableDataset
 from torch.utils.data import DataLoader
 
-from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
 from axolotl.utils.dict import DictDefault
 
@@ -77,14 +77,11 @@ class TestPretrainingPacking:
         )
 
         original_bsz = cfg.micro_batch_size
-        train_dataset = wrap_pretraining_dataset(
+        train_dataset = wrap_streaming_dataset(
            dataset,
            tokenizer_huggyllama,
            cfg,
            ds_wrapper_partial,
-           max_tokens=cfg.sequence_len,
-           batch_size=cfg.micro_batch_size,
-           seed=cfg.seed or 42,
        )
 
        trainer_loader = DataLoader(
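The dropped keyword arguments suggest `wrap_streaming_dataset` now reads these values off the config (`cfg.sequence_len`, `cfg.micro_batch_size`, `cfg.seed`) rather than taking them explicitly; that is an inference from this call site, not a documented signature. Callers migrating from `wrap_pretraining_dataset` would reduce to:

    # max_tokens / batch_size / seed no longer passed; the wrapper derives them from cfg.
    train_dataset = wrap_streaming_dataset(dataset, tokenizer, cfg, ds_wrapper_partial)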
tests/test_streaming.py (new file, 238 lines)
@@ -0,0 +1,238 @@
+"""Test streaming configuration and data loading functionality."""
+
+import unittest
+from unittest.mock import Mock, patch
+
+from datasets import IterableDataset
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.data.sft import (
+    _prepare_streaming_dataset,
+    prepare_datasets,
+)
+from axolotl.utils.config import validate_config
+
+
+class TestStreamingConfig(unittest.TestCase):
+    """Test streaming configuration and deprecation handling."""
+
+    def test_streaming_multipack_buffer_size_deprecation(self):
+        """Test that pretrain_multipack_buffer_size is properly deprecated."""
+        # Test with old config name
+        cfg_old = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "pretrain_multipack_buffer_size": 5000,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 0.0001,
+            }
+        )
+
+        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
+            validated_cfg = validate_config(cfg_old)
+        self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
+
+        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
+        self.assertIsNone(
+            getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
+        )
+
+    def test_streaming_multipack_buffer_size_new(self):
+        """Test that new streaming_multipack_buffer_size works correctly."""
+        cfg_new = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "streaming_multipack_buffer_size": 7000,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 0.0001,
+            }
+        )
+
+        validated_cfg = validate_config(cfg_new)
+        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)
+
+    def test_both_buffer_sizes_raises_error(self):
+        """Test that having both old and new buffer size configs raises an error."""
+        cfg_both = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "pretrain_multipack_buffer_size": 5000,
+                "streaming_multipack_buffer_size": 7000,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 0.0001,
+            }
+        )
+
+        with self.assertRaises(ValueError) as cm:
+            validate_config(cfg_both)
+        self.assertIn("both are set", str(cm.exception))
+
+
+class TestStreamingDatasetPreparation(unittest.TestCase):
+    """Test dataset preparation with streaming configuration."""
+
+    def setUp(self):
+        self.tokenizer = Mock()
+        self.tokenizer.pad_token_id = 0
+        self.tokenizer.eos_token_id = 1
+
+    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
+    def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
+        """Test that streaming=True triggers streaming dataset preparation."""
+        cfg = DictDefault(
+            {
+                "streaming": True,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+            }
+        )
+
+        mock_prepare_streaming.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+
+        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
+
+    @patch("axolotl.utils.data.sft._prepare_streaming_dataset")
+    def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
+        """Test that pretraining_dataset triggers streaming dataset preparation."""
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": "test/dataset",
+            }
+        )
+
+        mock_prepare_streaming.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+
+        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
+
+    @patch("axolotl.utils.data.sft._prepare_standard_dataset")
+    def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
+        """Test that without streaming, standard dataset preparation is used."""
+        cfg = DictDefault(
+            {
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+            }
+        )
+
+        mock_prepare_standard.return_value = (Mock(), None, 100, [])
+
+        prepare_datasets(cfg, self.tokenizer)
+
+        mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
+
+
+class TestStreamingWithSamplePacking(unittest.TestCase):
+    """Test streaming dataset preparation with sample packing."""
+
+    def setUp(self):
+        self.tokenizer = Mock()
+        self.tokenizer.pad_token_id = 0
+        self.tokenizer.eos_token_id = 1
+
+    @patch("axolotl.utils.data.sft._load_streaming_dataset")
+    def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
+        """Test that streaming SFT with sample_packing sets default split."""
+        cfg = DictDefault(
+            {
+                "streaming": True,
+                "sample_packing": True,
+                "datasets": [{"path": "test/dataset", "type": "alpaca"}],
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+            }
+        )
+
+        mock_load_streaming.return_value = Mock(spec=IterableDataset)
+
+        with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
+            _prepare_streaming_dataset(cfg, self.tokenizer, None)
+
+        # Check that the dataset config has split set to 'train'
+        call_args = mock_load_streaming.call_args
+        dataset_config = call_args[0][0]
+        self.assertEqual(dataset_config.split, "train")
+
+    def test_multipack_attn_forced_true_for_sft(self):
+        """Test that multipack_attn is forced to True for SFT with sample packing."""
+        from axolotl.utils.data.streaming import wrap_streaming_dataset
+
+        cfg = DictDefault(
+            {
+                "sample_packing": True,
+                "pretrain_multipack_attn": False,  # Should be overridden for SFT
+                "pretraining_dataset": None,  # This makes it SFT
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "streaming_multipack_buffer_size": 1000,
+                "seed": 42,
+            }
+        )
+
+        mock_dataset = Mock()
+        mock_dataset.features = None  # For streaming datasets
+        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
+        mock_dataset.map = Mock(return_value=mock_dataset)
+        mock_ds_wrapper = Mock()
+
+        with patch(
+            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
+        ) as mock_collator:
+            with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
+                wrap_streaming_dataset(
+                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
+                )
+
+        # Check that multipack_attn=True was used in the collator
+        mock_collator.assert_called_once()
+        call_kwargs = mock_collator.call_args[1]
+        self.assertTrue(call_kwargs["multipack_attn"])
+
+    def test_multipack_attn_respects_config_for_pretraining(self):
+        """Test that multipack_attn respects config for pretraining datasets."""
+        from axolotl.utils.data.streaming import wrap_streaming_dataset
+
+        cfg = DictDefault(
+            {
+                "sample_packing": True,
+                "pretrain_multipack_attn": False,  # Should be respected for pretraining
+                "pretraining_dataset": "test/dataset",  # This makes it pretraining
+                "sequence_len": 256,
+                "micro_batch_size": 1,
+                "streaming_multipack_buffer_size": 1000,
+                "seed": 42,
+            }
+        )
+
+        mock_dataset = Mock()
+        mock_dataset.features = None  # For streaming datasets
+        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator
+        mock_dataset.map = Mock(return_value=mock_dataset)
+        mock_ds_wrapper = Mock()
+
+        with patch(
+            "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
+        ) as mock_collator:
+            with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
+                wrap_streaming_dataset(
+                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
+                )
+
+        # Check that multipack_attn=False was used (respecting config)
+        mock_collator.assert_called_once()
+        call_kwargs = mock_collator.call_args[1]
+        self.assertFalse(call_kwargs["multipack_attn"])
+
+
+if __name__ == "__main__":
+    unittest.main()
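The migration behavior these config tests pin down (warn on the old key, carry its value over, fail if both keys are set) can be pictured with a shim along these lines. A hedged sketch only: the real logic lives in `axolotl.utils.schemas.validation`, and the function name and plain-dict handling here are illustrative assumptions:

    import logging

    LOG = logging.getLogger("axolotl.utils.schemas.validation")

    def migrate_buffer_size(cfg: dict) -> dict:
        """Hypothetical shim: pretrain_multipack_buffer_size -> streaming_multipack_buffer_size."""
        old = cfg.get("pretrain_multipack_buffer_size")
        new = cfg.get("streaming_multipack_buffer_size")
        if old is not None and new is not None:
            # Mirrors the "both are set" ValueError the tests assert on.
            raise ValueError(
                "pretrain_multipack_buffer_size and streaming_multipack_buffer_size: "
                "both are set; use only the new key"
            )
        if old is not None:
            LOG.warning(
                "`pretrain_multipack_buffer_size` is deprecated, use "
                "`streaming_multipack_buffer_size` instead"
            )
            cfg["streaming_multipack_buffer_size"] = old
            del cfg["pretrain_multipack_buffer_size"]
        return cfg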