fix: customized dataset with simpo (#2894) [skip ci]
This commit is contained in:
@@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
||||
system=sample[field_system], prompt=sample[field_prompt]
|
||||
)
|
||||
else:
|
||||
- sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
||||
+ sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
|
||||
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||
return sample
|
||||
|
||||
@@ -274,7 +274,7 @@ def validate_config(
|
||||
# Convert datasets to proper format if needed
|
||||
if cfg.get("datasets"):
|
||||
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
||||
- if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
|
||||
+ if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset):
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
|
||||
Reference in New Issue
Block a user