From 87565ecc05f1b8fd1f8b907dd750d3a5d09adf9a Mon Sep 17 00:00:00 2001
From: Leonard
Date: Fri, 17 Oct 2025 19:00:26 +0900
Subject: [PATCH] Add chat_template.argilla_chat support for DPO datasets (#3202)

* Add chat_template.argilla_chat support for DPO datasets

Creates a new chat_template.argilla_chat prompt strategy for handling
DPO datasets where chosen/rejected fields contain full conversations
(messages + final response), following the pattern of
chatml.argilla_chat and llama3.argilla_chat.

- Add argilla_chat() function to chat_template.py
- Add chat_template.argilla_chat to RLHF documentation
- Add test coverage for argilla_chat with multiple tokenizers

Dataset format:
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}

* Fix chat_template.argilla_chat return value contract and add docstring

- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value

Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.
* Update tests/prompt_strategies/test_dpo_chat_templates.py

Co-authored-by: Wing Lian

---------

Co-authored-by: Wing Lian
---
Review notes (commentary only: the text between the "---" line above and
the first "diff --git" line is neither commit message nor patch content).

- Line breaks and indentation in this patch were restored from a copy in
  which every whitespace run had been collapsed. Added/context line counts
  agree with the "@@" hunk headers, but exact indentation is reconstructed
  and should be confirmed against the upstream commit before applying.
- argilla_chat() builds "prompt" from the chosen conversation only
  (chosen_raw[:-1]); from the rejected conversation it reads just the last
  message, so the rejected conversation's earlier messages are never used.
- chosen_raw[-1] / rejected_raw[-1] raise IndexError if a sample's list is
  empty, and role_map[...] raises KeyError for a role that is not listed in
  the "roles" mapping.
- The response text is located with str.find(content) on the rendered
  template. find() returns the first match, so content that also occurs
  earlier in the rendered text (for example inside "[[dummy_message]]") is
  cut too early; if the content is not found, find() returns -1 and the
  slice keeps at most the last character.
- transform_fn declares tokenizer=None but calls
  tokenizer.apply_chat_template unconditionally.

 docs/rlhf.qmd                                 |  15 +++
 .../prompt_strategies/dpo/chat_template.py    | 120 ++++++++++++++++++
 .../test_dpo_chat_templates.py                |  78 +++++++++++-
 3 files changed, 212 insertions(+), 1 deletion(-)

diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 4a67b7559..594ebc743 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
 }
 ```
 
+#### chat_template.argilla_chat
+
+```json
+{
+    "chosen": [
+        {"role": "user", "content": "..."},
+        {"role": "assistant", "content": "..."}
+    ],
+    "rejected": [
+        {"role": "user", "content": "..."},
+        {"role": "assistant", "content": "..."}
+    ]
+}
+```
+
 #### chat_template.default
 
 ```yaml
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
index 85c4d2182..58b4d75bd 100644
--- a/src/axolotl/prompt_strategies/dpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs):
         return result
 
     return transform_fn, {"remove_columns": [field_messages]}
+
+
+def argilla_chat(cfg, dataset_idx=0, **kwargs):
+    """
+    DPO chat template strategy for argilla-style datasets.
+
+    For argilla-style datasets where chosen/rejected contain full conversations
+    instead of single response messages. Extracts the conversation history from
+    the chosen field and formats both chosen/rejected responses using the
+    configured chat template.
+
+    Args:
+        cfg: Configuration object containing chat_template and dataset settings
+        dataset_idx: Index of the dataset in the config (default: 0)
+        **kwargs: Additional keyword arguments (unused)
+
+    Returns:
+        tuple: (transform_fn, dataset_kwargs) where:
+            - transform_fn: Function to transform dataset samples
+            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop
+
+    Dataset format:
+        {
+            "chosen": [
+                {"role": "user", "content": "..."},
+                {"role": "assistant", "content": "..."}
+            ],
+            "rejected": [
+                {"role": "user", "content": "..."},
+                {"role": "assistant", "content": "..."}
+            ]
+        }
+    """
+    ds_cfg = cfg["datasets"][dataset_idx]
+    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
+
+    chat_template_choice, chat_template_jinja = extract_chat_template_args(
+        cfg=cfg, ds_cfg=ds_cfg
+    )
+    field_chosen = ds_cfg.get("field_chosen", "chosen")
+    field_rejected = ds_cfg.get("field_rejected", "rejected")
+    message_property_mappings = ds_cfg.get(
+        "message_property_mappings",
+        {
+            "role": "role",
+            "content": "content",
+        },
+    )
+    role_map_inv = ds_cfg.get(
+        "roles",
+        {
+            "user": ["user"],
+            "assistant": ["assistant"],
+            "system": ["system"],
+        },
+    )
+    role_map = {}
+    for target, sources in role_map_inv.items():
+        for source in sources:
+            role_map[source] = target
+
+    def transform_fn(sample, tokenizer=None):
+        chat_template_string = get_chat_template(
+            user_choice=chat_template_choice,
+            jinja_template=chat_template_jinja,
+            tokenizer=tokenizer,
+        )
+
+        chosen_raw = sample[field_chosen]
+        rejected_raw = sample[field_rejected]
+
+        # Extract messages (all but last) and responses (last message)
+        chosen_messages = [
+            {
+                "role": role_map[m[message_property_mappings["role"]]],
+                "content": m[message_property_mappings["content"]],
+            }
+            for m in chosen_raw[:-1]
+        ]
+        chosen_response = {
+            "role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
+            "content": chosen_raw[-1][message_property_mappings["content"]],
+        }
+
+        rejected_response = {
+            "role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
+            "content": rejected_raw[-1][message_property_mappings["content"]],
+        }
+
+        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
+
+        result = {}
+        result["prompt"] = tokenizer.apply_chat_template(
+            chosen_messages,
+            add_generation_prompt=True,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+
+        result["chosen"] = tokenizer.apply_chat_template(
+            [dummy_user_message, chosen_response],
+            add_generation_prompt=False,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+        chosen_strip_index = result["chosen"].find(chosen_response["content"])
+        result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
+
+        result["rejected"] = tokenizer.apply_chat_template(
+            [dummy_user_message, rejected_response],
+            add_generation_prompt=False,
+            chat_template=chat_template_string,
+            tokenize=False,
+        )
+        rejected_strip_index = result["rejected"].find(rejected_response["content"])
+        result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
+
+        return result
+
+    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
index e570cfc9d..b5c121726 100644
--- a/tests/prompt_strategies/test_dpo_chat_templates.py
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -8,7 +8,7 @@ import pytest
 from datasets import Dataset
 from transformers import AutoTokenizer
 
-from axolotl.prompt_strategies.dpo.chat_template import default
+from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
 from axolotl.utils.dict import DictDefault
 
 from tests.hf_offline_utils import enable_hf_offline
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
     )
 
 
+@pytest.fixture(name="argilla_chat_dataset")
+def fixture_argilla_chat_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "chosen": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "goodbye",
+                    },
+                ],
+                "rejected": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "party on",
+                    },
+                ],
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="phi3_tokenizer")
 @enable_hf_offline
 def fixture_phi3_tokenizer():
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
         assert result["rejected"] == "party on"
 
 
+class TestArgillaChatDPOChatTemplate:
+    """
+    Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
+    """
+
+    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
+        transform_fn, _ = argilla_chat(
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "datasets": [
+                        {
+                            "type": "chat_template.argilla_chat",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
+        assert result["prompt"] == (
+            "<|begin_of_text|>"
+            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
+        assert result["chosen"] == "goodbye<|eot_id|>"
+        assert result["rejected"] == "party on<|eot_id|>"
+
+    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
+        transform_fn, _ = argilla_chat(
+            DictDefault(
+                {
+                    "chat_template": "tokenizer_default",
+                    "datasets": [
+                        {
+                            "type": "chat_template.argilla_chat",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
+        assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
+        assert result["chosen"] == "goodbye<|end|>"
+        assert result["rejected"] == "party on<|end|>"
+
+
 if __name__ == "__main__":
     unittest.main()