* hf offline decorator for tests to work around rate limits * fail quicker so we can see logs * try new cache name * limit files downloaded * phi mini predownload * offline decorator for phi tokenizer * handle meta llama 8b offline too * make sure to return fixtures if they are wrapped too * more fixes * more things offline * more offline things * fix the env var * fix the model name * handle gemma also * force reload of modules to recheck offline status * prefetch mistral too * use reset_sessions so hub picks up offline mode * more fixes * rename so it doesn't seem like a context manager * fix backoff * switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline * include additional dataset * more fixes * more fixes * replace tiny shakespeare dataset * skip some tests for now * use more robust check using snapshot download to determine if a dataset name is on the hub * typo for skip reason * use local_files_only * more fixtures * remove local only * use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached * make sure fixtures aren't offline improve the offline reset try bumping version of datasets reorder reloading and setting prime a new cache run the tests now with fresh cache try with a static cache * now run all the ci again with hopefully a correct cache * skip wonky tests for now * skip wonky tests for now * handle offline mode for model card creation
66 lines
1.8 KiB
Python
66 lines
1.8 KiB
Python
"""
|
|
Tests for loading DPO preference datasets with chatml formatting
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import pytest
|
|
from utils import enable_hf_offline
|
|
|
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
    """Build the minimal DPO training config shared by the chatml tests.

    SmolLM2-135M supplies both the base model and the tokenizer, ``rl`` is
    set to ``"dpo"``, and ``<|endoftext|>`` is registered as the pad token.
    Individual tests merge their ``datasets`` section on top of this.
    """
    smol_lm = "HuggingFaceTB/SmolLM2-135M"
    base_options = {
        "base_model": smol_lm,
        "tokenizer_config": smol_lm,
        "rl": "dpo",
        "learning_rate": 1e-6,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "special_tokens": {"pad_token": "<|endoftext|>"},
        "sequence_len": 2048,
    }
    return DictDefault(base_options)
|
|
|
|
|
|
class TestDPOChatml:
    """
    Test loading DPO preference datasets with chatml formatting
    """

    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
    @enable_hf_offline
    def test_default(self, minimal_dpo_cfg):
        """Prepare a 1% train slice of an Orca DPO-pairs dataset via chatml.

        Checks that the ``chatml`` DPO strategy resolves, and that the first
        prepared sample has a chatml-rendered prompt ending in an open
        assistant turn plus ``chosen`` and ``rejected`` fields.

        Currently skipped until HF hub offline mode copes with rate limits.
        """
        cfg = DictDefault(
            {
                "datasets": [
                    {
                        "path": "argilla/distilabel-intel-orca-dpo-pairs",
                        "type": "chatml",
                        "split": "train[:1%]",
                    }
                ]
            }
            | minimal_dpo_cfg
        )

        # Check that dpo.load actually produced something: discarding the
        # result makes this step vacuous if the loader signals failure by
        # returning None instead of raising.
        # NOTE(review): assumes the strategy loader returns None on failure -
        # confirm against axolotl.prompt_strategies.
        transform = load_dpo("chatml", cfg)
        assert transform is not None, "chatml DPO strategy failed to load"

        # now actually load the datasets with the strategy
        train_ds, _ = load_prepare_preference_datasets(cfg)
        sample = train_ds[0]

        # Prompt is rendered in chatml and left open for the assistant reply.
        assert sample["prompt"].startswith("<|im_start|>")
        assert sample["prompt"].endswith("<|im_start|>assistant\n")
        assert "chosen" in sample
        assert "rejected" in sample
|
|
|
|
|
|
if __name__ == "__main__":
    # The test class in this module is a plain pytest class (not a
    # unittest.TestCase) and relies on pytest fixtures, so unittest.main()
    # would collect nothing. Hand the file to pytest and propagate its exit code.
    raise SystemExit(pytest.main([__file__]))
|