fix: customized dataset with simpo (#2894) [skip ci]

This commit is contained in:
Jiawei Liu
2025-07-12 10:40:30 -05:00
committed by GitHub
parent 4dc5910e1c
commit 7fb8441e0e
3 changed files with 17 additions and 19 deletions

View File

@@ -274,15 +274,14 @@ rl: dpo
datasets: datasets:
- path: ... - path: ...
split: train split: train
type: user_defined.default type:
field_prompt: "prompt"
field_prompt: "prompt" field_system: "system"
field_system: "system" field_chosen: "chosen"
field_chosen: "chosen" field_rejected: "rejected"
field_rejected: "rejected" prompt_format: "{prompt}"
prompt_format: "{prompt}" chosen_format: "{chosen}"
chosen_format: "{chosen}" rejected_format: "{rejected}"
rejected_format: "{rejected}"
``` ```
The input format is a simple JSON input with customizable fields based on the above config. The input format is a simple JSON input with customizable fields based on the above config.
@@ -475,14 +474,13 @@ rl: kto
datasets: datasets:
- path: ... - path: ...
split: train split: train
type: user_defined.default type:
field_prompt: "prompt"
field_prompt: "prompt" field_system: "system"
field_system: "system" field_completion: "completion"
field_completion: "completion" field_label: "label"
field_label: "label" prompt_format: "{prompt}"
prompt_format: "{prompt}" completion_format: "{completion}"
completion_format: "{completion}"
``` ```
The input format is a simple JSON input with customizable fields based on the above config. The input format is a simple JSON input with customizable fields based on the above config.

View File

@@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
system=sample[field_system], prompt=sample[field_prompt] system=sample[field_system], prompt=sample[field_prompt]
) )
else: else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen]) sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected]) sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample return sample

View File

@@ -274,7 +274,7 @@ def validate_config(
# Convert datasets to proper format if needed # Convert datasets to proper format if needed
if cfg.get("datasets"): if cfg.get("datasets"):
for idx, ds_cfg in enumerate(cfg["datasets"]): for idx, ds_cfg in enumerate(cfg["datasets"]):
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset): if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset):
cfg["datasets"][idx] = DPODataset(**ds_cfg) cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset): elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg)) cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))