make sure chatml dpo dataset loading works (#2333)
This commit is contained in:
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||||
|
def fixture_smollm2_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_mistralv03_tokenizer():
|
def fixture_mistralv03_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
|||||||
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Tests for loading DPO preference datasets with chatml formatting
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="minimal_dpo_cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
return DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"rl": "dpo",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"sequence_len": 2048,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDPOChatml:
|
||||||
|
"""
|
||||||
|
Test loading DPO preference datasets with chatml formatting
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_default(self, minimal_dpo_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||||
|
"type": "chatml",
|
||||||
|
"split": "train[:1%]",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
| minimal_dpo_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
# test that dpo.load works
|
||||||
|
load_dpo("chatml", cfg)
|
||||||
|
# now actually load the datasets with the strategy
|
||||||
|
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||||
|
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||||
|
assert "chosen" in train_ds[0]
|
||||||
|
assert "rejected" in train_ds[0]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user