diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index f04bd7f0d..f3427022f 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -91,4 +91,4 @@ def default( return result - return transform_fn + return transform_fn, {"remove_columns": [field_messages]} diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index b1802faa0..e5f30a6c4 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3: def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "llama3", @@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3: def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "llama3", @@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3: def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "tokenizer_default", @@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma: def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "tokenizer_default", diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 88d196ad1..bd77591cf 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -289,7 +289,10 @@ class TestDatasetPreparation: train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 - assert "conversation" in train_dataset.features + assert "conversation" not in train_dataset.features + assert "chosen" in train_dataset.features + assert "rejected" in train_dataset.features + assert "prompt" in train_dataset.features @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @enable_hf_offline @@ -348,7 +351,10 @@ class TestDatasetPreparation: train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 - assert "conversation" in train_dataset.features + assert "conversation" not in train_dataset.features + assert "chosen" in train_dataset.features + assert "rejected" in train_dataset.features + assert "prompt" in train_dataset.features @enable_hf_offline @pytest.mark.skip("datasets bug with local datasets when offline")