fix: customized dataset with simpo (#2894) [skip ci]

This commit is contained in:
Jiawei Liu
2025-07-12 10:40:30 -05:00
committed by GitHub
parent 4dc5910e1c
commit 7fb8441e0e
3 changed files with 17 additions and 19 deletions

View File

@@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample

View File

@@ -274,7 +274,7 @@ def validate_config(
# Convert datasets to proper format if needed
if cfg.get("datasets"):
for idx, ds_cfg in enumerate(cfg["datasets"]):
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset):
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))