From 7fb8441e0e0453a3b760996221332fed0a7cdb37 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Sat, 12 Jul 2025 10:40:30 -0500 Subject: [PATCH] fix: customized dataset with simpo (#2894) [skip ci] --- docs/rlhf.qmd | 32 +++++++++---------- .../prompt_strategies/dpo/user_defined.py | 2 +- src/axolotl/utils/config/__init__.py | 2 +- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 76131978f..4a67b7559 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -274,15 +274,14 @@ rl: dpo datasets: - path: ... split: train - type: user_defined.default - - field_prompt: "prompt" - field_system: "system" - field_chosen: "chosen" - field_rejected: "rejected" - prompt_format: "{prompt}" - chosen_format: "{chosen}" - rejected_format: "{rejected}" + type: + field_prompt: "prompt" + field_system: "system" + field_chosen: "chosen" + field_rejected: "rejected" + prompt_format: "{prompt}" + chosen_format: "{chosen}" + rejected_format: "{rejected}" ``` The input format is a simple JSON input with customizable fields based on the above config. @@ -475,14 +474,13 @@ rl: kto datasets: - path: ... split: train - type: user_defined.default - - field_prompt: "prompt" - field_system: "system" - field_completion: "completion" - field_label: "label" - prompt_format: "{prompt}" - completion_format: "{completion}" + type: + field_prompt: "prompt" + field_system: "system" + field_completion: "completion" + field_label: "label" + prompt_format: "{prompt}" + completion_format: "{completion}" ``` The input format is a simple JSON input with customizable fields based on the above config. diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 1d5f891af..cdd9b8c9c 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument system=sample[field_system], prompt=sample[field_prompt] ) else: - sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) + sample["prompt"] = prompt_format.format(prompt=sample[field_prompt]) sample["chosen"] = chosen_format.format(chosen=sample[field_chosen]) sample["rejected"] = rejected_format.format(rejected=sample[field_rejected]) return sample diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index fb17b259f..4de606565 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -274,7 +274,7 @@ def validate_config( # Convert datasets to proper format if needed if cfg.get("datasets"): for idx, ds_cfg in enumerate(cfg["datasets"]): - if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset): + if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset): cfg["datasets"][idx] = DPODataset(**ds_cfg) elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset): cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))