fix: customized dataset with simpo (#2894) [skip ci]
This commit is contained in:
@@ -274,15 +274,14 @@ rl: dpo
|
|||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
split: train
|
split: train
|
||||||
type: user_defined.default
|
type:
|
||||||
|
field_prompt: "prompt"
|
||||||
field_prompt: "prompt"
|
field_system: "system"
|
||||||
field_system: "system"
|
field_chosen: "chosen"
|
||||||
field_chosen: "chosen"
|
field_rejected: "rejected"
|
||||||
field_rejected: "rejected"
|
prompt_format: "{prompt}"
|
||||||
prompt_format: "{prompt}"
|
chosen_format: "{chosen}"
|
||||||
chosen_format: "{chosen}"
|
rejected_format: "{rejected}"
|
||||||
rejected_format: "{rejected}"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The input format is a simple JSON input with customizable fields based on the above config.
|
The input format is a simple JSON input with customizable fields based on the above config.
|
||||||
@@ -475,14 +474,13 @@ rl: kto
|
|||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
split: train
|
split: train
|
||||||
type: user_defined.default
|
type:
|
||||||
|
field_prompt: "prompt"
|
||||||
field_prompt: "prompt"
|
field_system: "system"
|
||||||
field_system: "system"
|
field_completion: "completion"
|
||||||
field_completion: "completion"
|
field_label: "label"
|
||||||
field_label: "label"
|
prompt_format: "{prompt}"
|
||||||
prompt_format: "{prompt}"
|
completion_format: "{completion}"
|
||||||
completion_format: "{completion}"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The input format is a simple JSON input with customizable fields based on the above config.
|
The input format is a simple JSON input with customizable fields based on the above config.
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
|||||||
system=sample[field_system], prompt=sample[field_prompt]
|
system=sample[field_system], prompt=sample[field_prompt]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
|
||||||
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||||
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||||
return sample
|
return sample
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ def validate_config(
|
|||||||
# Convert datasets to proper format if needed
|
# Convert datasets to proper format if needed
|
||||||
if cfg.get("datasets"):
|
if cfg.get("datasets"):
|
||||||
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
||||||
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
|
if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset):
|
||||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||||
|
|||||||
Reference in New Issue
Block a user