From 23a9fcb0a7610db37c47edf56d01c82c705db05f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 18 Feb 2025 16:08:40 -0500
Subject: [PATCH] make sure chatml dpo dataset loading works (#2333)

---
 tests/prompt_strategies/conftest.py        |  6 +++
 tests/prompt_strategies/test_dpo_chatml.py | 61 ++++++++++++++++++++++
 2 files changed, 67 insertions(+)
 create mode 100644 tests/prompt_strategies/test_dpo_chatml.py

diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py
index 9864a6fec..a7e417516 100644
--- a/tests/prompt_strategies/conftest.py
+++ b/tests/prompt_strategies/conftest.py
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
+def fixture_smollm2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+    return tokenizer
+
+
 @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
 def fixture_mistralv03_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained(
diff --git a/tests/prompt_strategies/test_dpo_chatml.py b/tests/prompt_strategies/test_dpo_chatml.py
new file mode 100644
index 000000000..34c29275b
--- /dev/null
+++ b/tests/prompt_strategies/test_dpo_chatml.py
@@ -0,0 +1,61 @@
+"""
+Tests for loading DPO preference datasets with chatml formatting
+"""
+import unittest
+
+import pytest
+
+from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.utils.data.rl import load_prepare_preference_datasets
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="minimal_dpo_cfg")
+def fixture_cfg():
+    return DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "rl": "dpo",
+            "learning_rate": 0.000001,
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "special_tokens": {
+                "pad_token": "<|endoftext|>",
+            },
+            "sequence_len": 2048,
+        }
+    )
+
+
+class TestDPOChatml:
+    """
+    Test loading DPO preference datasets with chatml formatting
+    """
+
+    def test_default(self, minimal_dpo_cfg):
+        cfg = DictDefault(
+            {
+                "datasets": [
+                    {
+                        "path": "argilla/distilabel-intel-orca-dpo-pairs",
+                        "type": "chatml",
+                        "split": "train[:1%]",
+                    }
+                ]
+            }
+            | minimal_dpo_cfg
+        )
+
+        # test that dpo.load works
+        load_dpo("chatml", cfg)
+        # now actually load the datasets with the strategy
+        train_ds, _ = load_prepare_preference_datasets(cfg)
+        assert train_ds[0]["prompt"].startswith("<|im_start|>")
+        assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
+        assert "chosen" in train_ds[0]
+        assert "rejected" in train_ds[0]
+
+
+if __name__ == "__main__":
+    unittest.main()