Support user-defined prompt processing strategies for dpo (#1248)
* support user-defined prompt processing strategies for dpo * interpret dict dataset types as user-defined * fix lint errors * setup pydantic config for validation of User defined DPO --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -8,14 +8,13 @@ import logging
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg):
|
||||
def load(strategy, cfg, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
return func(cfg, **load_kwargs)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
|
||||
@@ -5,6 +5,7 @@ DPO strategies for chatml
|
||||
|
||||
def argilla(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
@@ -25,6 +26,7 @@ def argilla(
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
@@ -48,7 +50,7 @@ def icr(
|
||||
return transform_fn
|
||||
|
||||
|
||||
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
For Intel Orca DPO Pairs
|
||||
"""
|
||||
@@ -70,7 +72,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
return transform_fn
|
||||
|
||||
|
||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def prompt_pairs(
|
||||
cfg, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
@@ -88,7 +92,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
|
||||
return transform_fn
|
||||
|
||||
|
||||
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for ultrafeedback binarized conversations
|
||||
"""
|
||||
|
||||
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
User-defined DPO strategies
|
||||
"""
|
||||
|
||||
|
||||
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]["type"]
|
||||
if not isinstance(ds_cfg, dict):
|
||||
raise ValueError(
|
||||
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
|
||||
)
|
||||
field_prompt = ds_cfg.get("field_prompt", "prompt")
|
||||
field_system = ds_cfg.get("field_system", "system")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
prompt_format = ds_cfg.get("prompt_format")
|
||||
if not prompt_format:
|
||||
prompt_format = "{" + field_prompt + "}"
|
||||
chosen_format = ds_cfg.get("chosen_format")
|
||||
if not chosen_format:
|
||||
chosen_format = "{" + field_chosen + "}"
|
||||
rejected_format = ds_cfg.get("rejected_format")
|
||||
if not rejected_format:
|
||||
rejected_format = "{" + field_rejected + "}"
|
||||
|
||||
def transform_fn(sample):
|
||||
if (
|
||||
"{" + field_system + "}" in prompt_format
|
||||
and field_system in sample
|
||||
and sample[field_system]
|
||||
):
|
||||
sample["prompt"] = prompt_format.format(
|
||||
system=sample[field_system], prompt=sample[field_prompt]
|
||||
)
|
||||
else:
|
||||
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
||||
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -3,7 +3,7 @@ DPO strategies for zephyr
|
||||
"""
|
||||
|
||||
|
||||
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
data = {}
|
||||
data["prompt"] = (
|
||||
|
||||
@@ -85,12 +85,24 @@ class SFTDataset(BaseModel):
|
||||
field_model: Optional[str] = None
|
||||
|
||||
|
||||
class UserDefinedDPOType(BaseModel):
|
||||
"""User defined typing for DPO"""
|
||||
|
||||
field_system: Optional[str] = None
|
||||
field_prompt: Optional[str] = None
|
||||
field_chosen: Optional[str] = None
|
||||
field_rejected: Optional[str] = None
|
||||
prompt_format: Optional[str] = None
|
||||
chosen_format: Optional[str] = None
|
||||
rejected_format: Optional[str] = None
|
||||
|
||||
|
||||
class DPODataset(BaseModel):
|
||||
"""DPO configuration subset"""
|
||||
|
||||
path: Optional[str] = None
|
||||
split: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
|
||||
|
||||
|
||||
@@ -937,7 +937,9 @@ def load_prepare_dpo_datasets(cfg):
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
ds_transform_fn = load_dpo(_type, _cfg)
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
split_datasets[i] = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
|
||||
Reference in New Issue
Block a user