Add chat_template.argilla_chat support for DPO datasets (#3202)
* Add chat_template.argilla_chat support for DPO datasets
Creates a new chat_template.argilla_chat prompt strategy for handling
DPO datasets where chosen/rejected fields contain full conversations
(messages + final response), following the pattern of chatml.argilla_chat
and llama3.argilla_chat.
- Add argilla_chat() function to chat_template.py
- Add chat_template.argilla_chat to RLHF documentation
- Add test coverage for argilla_chat with multiple tokenizers
Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
* Fix chat_template.argilla_chat return value contract and add docstring
- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value
Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.
* Update tests/prompt_strategies/test_dpo_chat_templates.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="argilla_chat_dataset")
|
||||
def fixture_argilla_chat_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"chosen": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
],
|
||||
"rejected": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "party on",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="phi3_tokenizer")
|
||||
@enable_hf_offline
|
||||
def fixture_phi3_tokenizer():
|
||||
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
assert result["rejected"] == "party on<end_of_turn>"
|
||||
|
||||
|
||||
class TestArgillaChatDPOChatTemplate:
|
||||
"""
|
||||
Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
|
||||
"""
|
||||
|
||||
def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
|
||||
transform_fn, _ = argilla_chat(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template.argilla_chat",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<|begin_of_text|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
|
||||
transform_fn, _ = argilla_chat(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template.argilla_chat",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
|
||||
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
|
||||
assert result["chosen"] == "goodbye<|end|>"
|
||||
assert result["rejected"] == "party on<|end|>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user