remove unused field for chat_template.default for DPO training (#2755) [skip ci]

* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

* lint and fix tests for new return value

* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

lint and fix tests for new return value

fix for updated expected fields for dpo

remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

fix test still expecting "messages" field

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Timofey Klyubin
2025-06-05 10:22:58 -04:00
committed by GitHub
parent e8e45b3441
commit 4440b4a1ce
3 changed files with 13 additions and 7 deletions

View File

@@ -91,4 +91,4 @@ def default(
return result
-return transform_fn
+return transform_fn, {"remove_columns": [field_messages]}

View File

@@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3:
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
-transform_fn = default(
+transform_fn, _ = default(
DictDefault(
{
"chat_template": "llama3",
@@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3:
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
-transform_fn = default(
+transform_fn, _ = default(
DictDefault(
{
"chat_template": "llama3",
@@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3:
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
-transform_fn = default(
+transform_fn, _ = default(
DictDefault(
{
"chat_template": "tokenizer_default",
@@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma:
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
-transform_fn = default(
+transform_fn, _ = default(
DictDefault(
{
"chat_template": "tokenizer_default",

View File

@@ -289,7 +289,10 @@ class TestDatasetPreparation:
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
-assert "conversation" in train_dataset.features
+assert "conversation" not in train_dataset.features
+assert "chosen" in train_dataset.features
+assert "rejected" in train_dataset.features
+assert "prompt" in train_dataset.features
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
@@ -348,7 +351,10 @@ class TestDatasetPreparation:
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
-assert "conversation" in train_dataset.features
+assert "conversation" not in train_dataset.features
+assert "chosen" in train_dataset.features
+assert "rejected" in train_dataset.features
+assert "prompt" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")