remove unused field for chat_template.default for DPO training (#2755) [skip ci]
* remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise * lint and fix tests for new return value * remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise lint and fix tests for new return value fix for updated expected fields for dpo remove unused field for chat_template.default "messages" field present in final dataset causes issues with DPO training otherwise fix test still expecting "messages" field * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -91,4 +91,4 @@ def default(
|
||||
|
||||
return result
|
||||
|
||||
return transform_fn
|
||||
return transform_fn, {"remove_columns": [field_messages]}
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
|
||||
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
@@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
|
||||
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
@@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
|
||||
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
@@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
|
||||
@@ -289,7 +289,10 @@ class TestDatasetPreparation:
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
assert "conversation" not in train_dataset.features
|
||||
assert "chosen" in train_dataset.features
|
||||
assert "rejected" in train_dataset.features
|
||||
assert "prompt" in train_dataset.features
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
@@ -348,7 +351,10 @@ class TestDatasetPreparation:
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
assert "conversation" not in train_dataset.features
|
||||
assert "chosen" in train_dataset.features
|
||||
assert "rejected" in train_dataset.features
|
||||
assert "prompt" in train_dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
|
||||
Reference in New Issue
Block a user