diff --git a/tests/prompt_strategies/test_alpacha.py b/tests/prompt_strategies/test_alpacha.py index ba2098938..882307d69 100644 --- a/tests/prompt_strategies/test_alpacha.py +++ b/tests/prompt_strategies/test_alpacha.py @@ -15,7 +15,11 @@ from axolotl.prompters import AlpacaPrompter def fixture_alpacha_dataset(): return Dataset.from_list( [ - {"instruction": "Evaluate this sentence for spelling and grammar mistakes", "input": "He finnished his meal and left the resturant", "output": "He finished his meal and left the restaurant."} + { + "instruction": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant.", + } ] ) @@ -46,9 +50,7 @@ class TestAlpacha: def test_no_double_im_end(self, alpacha_dataset, tokenizer): strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter( - prompt_style="chatml" - ), + AlpacaPrompter(prompt_style="chatml"), tokenizer, False, # train_on_inputs 2048, # sequence_len @@ -60,13 +62,93 @@ class TestAlpacha: input_ids = dataset_wrapper[0]["input_ids"] - assert input_ids == [1, 32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, 32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, 32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000] - + assert input_ids == [ + 1, + 32001, + 1587, + 13, + 20548, + 336, + 349, + 396, + 13126, + 369, + 13966, + 264, + 3638, + 28725, + 5881, + 1360, + 395, + 396, + 2787, + 369, + 5312, + 3629, + 2758, + 28723, + 12018, + 264, + 2899, + 369, + 6582, + 1999, + 2691, + 274, + 272, + 2159, + 28723, + 32000, + 28705, + 13, + 32001, + 2188, + 13, + 16627, + 11931, + 456, + 12271, + 354, + 668, + 3572, + 304, + 18756, + 3479, + 17179, + 13, + 2428, + 854, + 28711, + 1497, + 516, + 11314, + 304, + 1749, + 272, + 1846, + 324, + 440, + 32000, + 28705, + 13, + 32001, + 13892, + 13, + 650, + 5967, + 516, + 11314, + 304, + 1749, + 272, + 9926, + 28723, + 32000, + ] + def test_no_train_on_input(self, alpacha_dataset, tokenizer): strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter( - prompt_style="chatml" - ), + AlpacaPrompter(prompt_style="chatml"), tokenizer, False, # train_on_inputs 2048, # sequence_len @@ -78,13 +160,93 @@ class TestAlpacha: labels = dataset_wrapper[0]["labels"] - assert labels == [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000] + assert labels == [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 650, + 5967, + 516, + 11314, + 304, + 1749, + 272, + 9926, + 28723, + 32000, + ] def test_w_train_on_input(self, alpacha_dataset, tokenizer): strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter( - prompt_style="chatml" - ), + AlpacaPrompter(prompt_style="chatml"), tokenizer, True, # train_on_inputs 2048, # sequence_len @@ -96,4 +258,86 @@ class TestAlpacha: labels = dataset_wrapper[0]["labels"] - assert labels == [1, 32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, 32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, 32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000] \ No newline at end of file + assert labels == [ + 1, + 32001, + 1587, + 13, + 20548, + 336, + 349, + 396, + 13126, + 369, + 13966, + 264, + 3638, + 28725, + 5881, + 1360, + 395, + 396, + 2787, + 369, + 5312, + 3629, + 2758, + 28723, + 12018, + 264, + 2899, + 369, + 6582, + 1999, + 2691, + 274, + 272, + 2159, + 28723, + 32000, + 28705, + 13, + 32001, + 2188, + 13, + 16627, + 11931, + 456, + 12271, + 354, + 668, + 3572, + 304, + 18756, + 3479, + 17179, + 13, + 2428, + 854, + 28711, + 1497, + 516, + 11314, + 304, + 1749, + 272, + 1846, + 324, + 440, + 32000, + 28705, + 13, + 32001, + 13892, + 13, + 650, + 5967, + 516, + 11314, + 304, + 1749, + 272, + 9926, + 28723, + 32000, + ]