Add chat_template.argilla_chat support for DPO datasets (#3202)
* Add chat_template.argilla_chat support for DPO datasets
Creates a new chat_template.argilla_chat prompt strategy for handling
DPO datasets where chosen/rejected fields contain full conversations
(messages + final response), following the pattern of chatml.argilla_chat
and llama3.argilla_chat.
- Add argilla_chat() function to chat_template.py
- Add chat_template.argilla_chat to RLHF documentation
- Add test coverage for argilla_chat with multiple tokenizers
Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
* Fix chat_template.argilla_chat return value contract and add docstring
- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value
Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.
* Update tests/prompt_strategies/test_dpo_chat_templates.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### chat_template.argilla_chat
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "..."},
|
||||||
|
{"role": "assistant", "content": "..."}
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "..."},
|
||||||
|
{"role": "assistant", "content": "..."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
#### chat_template.default
|
#### chat_template.default
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
return transform_fn, {"remove_columns": [field_messages]}
|
return transform_fn, {"remove_columns": [field_messages]}
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(cfg, dataset_idx=0, **kwargs):
    """
    DPO chat template strategy for argilla-style datasets.

    For argilla-style datasets where chosen/rejected contain full conversations
    instead of single response messages. The conversation history (every
    message but the last) is taken from the chosen field and rendered as the
    prompt; the last message of the chosen and rejected conversations are
    rendered as the two candidate responses using the configured chat template.
    The history contained in the rejected field is not used.

    Args:
        cfg: Configuration object containing chat_template and dataset settings
        dataset_idx: Index of the dataset in the config (default: 0)
        **kwargs: Additional keyword arguments (unused)

    Returns:
        tuple: (transform_fn, dataset_kwargs) where:
            - transform_fn: Function mapping a sample (and a tokenizer) to a
              dict with "prompt", "chosen" and "rejected" strings
            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop

    Raises (from transform_fn):
        ValueError: if no tokenizer is supplied, if either conversation is
            empty, or if a response cannot be located in the rendered template
            output.
        KeyError: if a message uses a role that is not present in the
            configured role mapping.

    Dataset format:
        {
            "chosen": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ],
            "rejected": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    # Invert {target_role: [source_roles]} into {source_role: target_role}.
    role_map = {
        source: target
        for target, sources in role_map_inv.items()
        for source in sources
    }

    dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

    def _normalize(message):
        # Translate a raw dataset message into the {"role", "content"} shape
        # expected by the chat template, applying the configured role mapping.
        return {
            "role": role_map[message[message_property_mappings["role"]]],
            "content": message[message_property_mappings["content"]],
        }

    def _render_response(response, tokenizer, chat_template_string, field_name):
        # Render the response behind a placeholder user turn, then keep only
        # the text starting at the response content (content + end tokens).
        rendered = tokenizer.apply_chat_template(
            [dummy_user_message, response],
            add_generation_prompt=False,
            chat_template=chat_template_string,
            tokenize=False,
        )

        # Begin the search after the placeholder turn; otherwise a short
        # response such as "user" or "dummy" would match inside the
        # placeholder or its role header and the slice would start too early.
        placeholder = dummy_user_message["content"]
        placeholder_index = rendered.find(placeholder)
        search_from = (
            placeholder_index + len(placeholder) if placeholder_index >= 0 else 0
        )

        content = response["content"]
        start = rendered.find(content, search_from)
        if start < 0:
            # Some templates trim message content before rendering it, so
            # retry with the stripped text instead of slicing from index -1.
            start = rendered.find(content.strip(), search_from)
        if start < 0:
            raise ValueError(
                f"Could not locate the '{field_name}' response in the rendered "
                "chat template output"
            )
        return rendered[start:].rstrip()

    def transform_fn(sample, tokenizer=None):
        if tokenizer is None:
            raise ValueError(
                "chat_template.argilla_chat requires a tokenizer to apply the "
                "chat template"
            )

        chat_template_string = get_chat_template(
            user_choice=chat_template_choice,
            jinja_template=chat_template_jinja,
            tokenizer=tokenizer,
        )

        chosen_raw = sample[field_chosen]
        rejected_raw = sample[field_rejected]
        if not chosen_raw or not rejected_raw:
            raise ValueError(
                f"Fields '{field_chosen}' and '{field_rejected}' must each "
                "contain at least one message"
            )

        # Extract messages (all but last) and responses (last message)
        chosen_messages = [_normalize(m) for m in chosen_raw[:-1]]
        chosen_response = _normalize(chosen_raw[-1])
        rejected_response = _normalize(rejected_raw[-1])

        result = {}
        result["prompt"] = tokenizer.apply_chat_template(
            chosen_messages,
            add_generation_prompt=True,
            chat_template=chat_template_string,
            tokenize=False,
        )
        result["chosen"] = _render_response(
            chosen_response, tokenizer, chat_template_string, field_chosen
        )
        result["rejected"] = _render_response(
            rejected_response, tokenizer, chat_template_string, field_rejected
        )

        return result

    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
    """
    Single-row argilla-style dataset: both conversations share the same user
    turn and differ only in the final assistant reply.
    """
    user_turn = {"role": "user", "content": "hello"}
    preferred_reply = {"role": "assistant", "content": "goodbye"}
    dispreferred_reply = {"role": "assistant", "content": "party on"}

    row = {
        "chosen": [dict(user_turn), preferred_reply],
        "rejected": [dict(user_turn), dispreferred_reply],
    }
    return Dataset.from_list([row])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="phi3_tokenizer")
|
@pytest.fixture(name="phi3_tokenizer")
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def fixture_phi3_tokenizer():
|
def fixture_phi3_tokenizer():
|
||||||
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
|
|||||||
assert result["rejected"] == "party on<end_of_turn>"
|
assert result["rejected"] == "party on<end_of_turn>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestArgillaChatDPOChatTemplate:
    """
    Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
    """

    @staticmethod
    def _build_transform(chat_template):
        # Build the strategy for a single argilla_chat dataset entry and keep
        # only the transform function (the dataset kwargs are not under test).
        dataset_entry = {"type": "chat_template.argilla_chat"}
        config = DictDefault(
            {
                "chat_template": chat_template,
                "datasets": [dataset_entry],
            }
        )
        transform_fn, _ = argilla_chat(config)
        return transform_fn

    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
        transform_fn = self._build_transform("llama3")
        sample = argilla_chat_dataset[0]

        result = transform_fn(sample, tokenizer=llama3_tokenizer)

        expected_prompt = (
            "<|begin_of_text|>"
            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
        transform_fn = self._build_transform("tokenizer_default")
        sample = argilla_chat_dataset[0]

        result = transform_fn(sample, tokenizer=phi3_tokenizer)

        expected_prompt = "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user