def argilla_chat(cfg, dataset_idx=0, **kwargs):
    """
    DPO chat template strategy for argilla-style datasets.

    For argilla-style datasets where chosen/rejected contain full conversations
    instead of single response messages. The conversation history (every
    message but the last) is taken from the chosen field only; the last message
    of the chosen and rejected fields are the two competing responses. All
    three parts are formatted with the configured chat template.

    Args:
        cfg: Configuration object containing chat_template and dataset settings
        dataset_idx: Index of the dataset in the config (default: 0)
        **kwargs: Additional keyword arguments (unused)

    Returns:
        tuple: (transform_fn, dataset_kwargs) where:
            - transform_fn: Function to transform dataset samples into a dict
              with "prompt", "chosen" and "rejected" strings
            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop

    Raises (from transform_fn):
        KeyError: if a message role is not present in the configured role map
        IndexError: if the chosen or rejected conversation is empty
        ValueError: if the response content cannot be located in the rendered
            chat template output

    Dataset format:
        {
            "chosen": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ],
            "rejected": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    # Invert {target_role: [source_roles]} into {source_role: target_role}.
    role_map = {}
    for target, sources in role_map_inv.items():
        for source in sources:
            role_map[source] = target

    dummy_content = "[[dummy_message]]"

    def _normalize(message):
        """Map one raw dataset message onto canonical role/content keys."""
        return {
            "role": role_map[message[message_property_mappings["role"]]],
            "content": message[message_property_mappings["content"]],
        }

    def _format_response(response, tokenizer, chat_template_string):
        """Render a single response and cut away everything before its content."""
        # The response is rendered behind a placeholder user turn so the
        # template emits the assistant framing; that framing is removed below.
        rendered = tokenizer.apply_chat_template(
            [{"role": "user", "content": dummy_content}, response],
            add_generation_prompt=False,
            chat_template=chat_template_string,
            tokenize=False,
        )
        content = response["content"]

        # Only look behind the placeholder turn, so a response that also occurs
        # in the user header or in the placeholder is not matched too early.
        dummy_at = rendered.find(dummy_content)
        search_from = dummy_at + len(dummy_content) if dummy_at != -1 else 0

        start = rendered.find(content, search_from)
        if start == -1:
            # A template may trim whitespace around the message content, in
            # which case the raw content is not present verbatim.
            trimmed = content.strip()
            if trimmed:
                start = rendered.find(trimmed, search_from)
        if start == -1:
            # Slicing with -1 would silently keep only the last character.
            raise ValueError(
                "Could not locate the response content in the rendered chat "
                f"template output; content={content!r}"
            )
        return rendered[start:].rstrip()

    def transform_fn(sample, tokenizer=None):
        chat_template_string = get_chat_template(
            user_choice=chat_template_choice,
            jinja_template=chat_template_jinja,
            tokenizer=tokenizer,
        )

        chosen_raw = sample[field_chosen]
        rejected_raw = sample[field_rejected]

        # Extract messages (all but last) and responses (last message)
        chosen_messages = [_normalize(m) for m in chosen_raw[:-1]]
        chosen_response = _normalize(chosen_raw[-1])
        rejected_response = _normalize(rejected_raw[-1])

        result = {}
        result["prompt"] = tokenizer.apply_chat_template(
            chosen_messages,
            add_generation_prompt=True,
            chat_template=chat_template_string,
            tokenize=False,
        )
        result["chosen"] = _format_response(
            chosen_response, tokenizer, chat_template_string
        )
        result["rejected"] = _format_response(
            rejected_response, tokenizer, chat_template_string
        )

        return result

    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
class TestArgillaChatDPOChatTemplate:
    """
    Tests for the argilla_chat strategy, where the chosen and rejected fields
    each hold a full conversation rather than a single response.
    """

    @staticmethod
    def _make_cfg(chat_template):
        """Build a one-dataset config that selects the argilla_chat strategy."""
        dataset_entry = {
            "type": "chat_template.argilla_chat",
        }
        return DictDefault(
            {
                "chat_template": chat_template,
                "datasets": [dataset_entry],
            }
        )

    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
        cfg = self._make_cfg("llama3")
        transform_fn, _ = argilla_chat(cfg)
        sample = argilla_chat_dataset[0]

        result = transform_fn(sample, tokenizer=llama3_tokenizer)

        expected_prompt = (
            "<|begin_of_text|>"
            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
        cfg = self._make_cfg("tokenizer_default")
        transform_fn, _ = argilla_chat(cfg)
        sample = argilla_chat_dataset[0]

        result = transform_fn(sample, tokenizer=phi3_tokenizer)

        expected_prompt = "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"