Add chat_template.argilla_chat support for DPO datasets (#3202)

* Add chat_template.argilla_chat support for DPO datasets

  Creates a new chat_template.argilla_chat prompt strategy for handling
  DPO datasets where chosen/rejected fields contain full conversations
  (messages + final response), following the pattern of chatml.argilla_chat
  and llama3.argilla_chat.

  - Add argilla_chat() function to chat_template.py
  - Add chat_template.argilla_chat to RLHF documentation
  - Add test coverage for argilla_chat with multiple tokenizers

  Dataset format:
  {
    "chosen": [
      {"role": "user", "content": "..."},
      {"role": "assistant", "content": "..."}
    ],
    "rejected": [
      {"role": "user", "content": "..."},
      {"role": "assistant", "content": "..."}
    ]
  }

* Fix chat_template.argilla_chat return value contract and add docstring

- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value

Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.

* Update tests/prompt_strategies/test_dpo_chat_templates.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Leonard
2025-10-17 19:00:26 +09:00
committed by GitHub
parent 93ba57396f
commit 87565ecc05
3 changed files with 212 additions and 1 deletions

View File

@@ -8,7 +8,7 @@ import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
)
@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
return Dataset.from_list(
[
{
"chosen": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "goodbye",
},
],
"rejected": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "party on",
},
],
}
]
)
@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
assert result["rejected"] == "party on<end_of_turn>"
class TestArgillaChatDPOChatTemplate:
"""
Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
"""
def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
if __name__ == "__main__":
unittest.main()