diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index e8c7f4088..585696e29 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -24,6 +24,25 @@ def argilla( return transform_fn +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/dpo-mix-7k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + return transform_fn + + def icr( cfg, **kwargs,