Streaming SFT support (#3101)
* working * fixes * deprecate --iterable; cleanup * pretrain_multipack_buffer_size -> streaming_multipack_buffer_size * improvements * tests * remove unused * docs, examples * nit * nit * add val_set_size validation * val * nit * min * coderabbito * cleanup * nit * add depr warning, cleanup * nit * fix test, fix quarto * fix * review comments * review comments * fix
This commit is contained in:
@@ -1,16 +1,11 @@
|
||||
"""Module for testing dataset sequence packing"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
from axolotl.train import setup_model_and_trainer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -35,43 +30,6 @@ class TestPacking(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
def test_increments_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
dateset = load_dataset(
|
||||
"json",
|
||||
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
|
||||
)["train"]
|
||||
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
|
||||
|
||||
constant_len_dataset = ConstantLengthDataset(
|
||||
self.tokenizer,
|
||||
[dataset],
|
||||
seq_length=2048,
|
||||
)
|
||||
packed_dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
example = packed_dataset[0]
|
||||
next_bos_index = (
|
||||
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
|
||||
) # add one since we sliced
|
||||
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
assert example["position_ids"][0] == 0
|
||||
assert example["position_ids"][1] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 2
|
||||
assert example["position_ids"][next_bos_index] == 0
|
||||
assert example["position_ids"][next_bos_index + 1] == 1
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user