Support user-defined prompt processing strategies for dpo (#1248)

* support user-defined prompt processing strategies for dpo

* interpret dict dataset types as user-defined

* fix lint errors

* setup pydantic config for validation of User defined DPO

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
nopperl
2024-02-26 23:49:34 +00:00
committed by GitHub
parent 16482796b0
commit 1e3d5305d3
6 changed files with 67 additions and 9 deletions

View File

@@ -8,14 +8,13 @@ import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg):
def load(strategy, cfg, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -5,6 +5,7 @@ DPO strategies for chatml
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
@@ -25,6 +26,7 @@ def argilla(
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
@@ -48,7 +50,7 @@ def icr(
return transform_fn
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
@@ -70,7 +72,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
return transform_fn
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
@@ -88,7 +92,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
return transform_fn
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

View File

@@ -0,0 +1,41 @@
"""
User-defined DPO strategies
"""
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
)
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"
def transform_fn(sample):
if (
"{" + field_system + "}" in prompt_format
and field_system in sample
and sample[field_system]
):
sample["prompt"] = prompt_format.format(
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample
return transform_fn

View File

@@ -3,7 +3,7 @@ DPO strategies for zephyr
"""
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (

View File

@@ -85,12 +85,24 @@ class SFTDataset(BaseModel):
field_model: Optional[str] = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
field_system: Optional[str] = None
field_prompt: Optional[str] = None
field_chosen: Optional[str] = None
field_rejected: Optional[str] = None
prompt_format: Optional[str] = None
chosen_format: Optional[str] = None
rejected_format: Optional[str] = None
class DPODataset(BaseModel):
"""DPO configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
type: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None

View File

@@ -937,7 +937,9 @@ def load_prepare_dpo_datasets(cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
ds_transform_fn = load_dpo(_type, _cfg)
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",