From 4440b4a1ce9ee0e1702ab17f6e1852985c8220a3 Mon Sep 17 00:00:00 2001 From: Timofey Klyubin Date: Thu, 5 Jun 2025 10:22:58 -0400 Subject: [PATCH] remove unused field for chat_template.default for DPO training (#2755) [skip ci] * remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise * lint and fix tests for new return value * remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise lint and fix tests for new return value fix for updated expected fields for dpo remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise fix test still expecting "messages" field * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/prompt_strategies/dpo/chat_template.py | 2 +- tests/prompt_strategies/test_dpo_chat_templates.py | 8 ++++---- tests/test_datasets.py | 10 ++++++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index f04bd7f0d..f3427022f 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -91,4 +91,4 @@ def default( return result - return transform_fn + return transform_fn, {"remove_columns": [field_messages]} diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index b1802faa0..e5f30a6c4 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3: def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "llama3", @@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3: def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "llama3", @@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3: def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "tokenizer_default", @@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma: def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): # pylint: disable=duplicate-code - transform_fn = default( + transform_fn, _ = default( DictDefault( { "chat_template": "tokenizer_default", diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 88d196ad1..bd77591cf 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -289,7 +289,10 @@ class TestDatasetPreparation: train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 - assert "conversation" in train_dataset.features + assert "conversation" not in train_dataset.features + assert "chosen" in train_dataset.features + assert "rejected" in train_dataset.features + assert "prompt" in train_dataset.features @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @enable_hf_offline @@ -348,7 +351,10 @@ class TestDatasetPreparation: train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 - assert "conversation" in train_dataset.features + assert "conversation" not in train_dataset.features + assert "chosen" in train_dataset.features + assert "rejected" in train_dataset.features + assert "prompt" in train_dataset.features @enable_hf_offline @pytest.mark.skip("datasets bug with local datasets when offline")