Support user-defined prompt processing strategies for dpo (#1248)
* support user-defined prompt processing strategies for dpo * interpret dict dataset types as user-defined * fix lint errors * setup pydantic config for validation of User defined DPO --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -8,14 +8,13 @@ import logging
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, cfg):
|
def load(strategy, cfg, **kwargs):
|
||||||
try:
|
try:
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
strategy = ".".join(strategy.split(".")[:-1])
|
||||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
load_kwargs = {}
|
return func(cfg, **kwargs)
|
||||||
return func(cfg, **load_kwargs)
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
LOG.warning(f"unable to load strategy {strategy}")
|
LOG.warning(f"unable to load strategy {strategy}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ DPO strategies for chatml
|
|||||||
|
|
||||||
def argilla(
|
def argilla(
|
||||||
cfg,
|
cfg,
|
||||||
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
@@ -25,6 +26,7 @@ def argilla(
|
|||||||
|
|
||||||
def icr(
|
def icr(
|
||||||
cfg,
|
cfg,
|
||||||
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
"""
|
"""
|
||||||
chatml transforms for datasets with system, input, chosen, rejected
|
chatml transforms for datasets with system, input, chosen, rejected
|
||||||
@@ -48,7 +50,7 @@ def icr(
|
|||||||
return transform_fn
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
"""
|
"""
|
||||||
For Intel Orca DPO Pairs
|
For Intel Orca DPO Pairs
|
||||||
"""
|
"""
|
||||||
@@ -70,7 +72,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
|||||||
return transform_fn
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
def prompt_pairs(
|
||||||
|
cfg, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
sample["prompt"] = (
|
sample["prompt"] = (
|
||||||
@@ -88,7 +92,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
|
|||||||
return transform_fn
|
return transform_fn
|
||||||
|
|
||||||
|
|
||||||
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
"""
|
"""
|
||||||
for ultrafeedback binarized conversations
|
for ultrafeedback binarized conversations
|
||||||
"""
|
"""
|
||||||
|
|||||||
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
User-defined DPO strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
||||||
|
ds_cfg = cfg["datasets"][dataset_idx]["type"]
|
||||||
|
if not isinstance(ds_cfg, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
|
||||||
|
)
|
||||||
|
field_prompt = ds_cfg.get("field_prompt", "prompt")
|
||||||
|
field_system = ds_cfg.get("field_system", "system")
|
||||||
|
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||||
|
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||||
|
prompt_format = ds_cfg.get("prompt_format")
|
||||||
|
if not prompt_format:
|
||||||
|
prompt_format = "{" + field_prompt + "}"
|
||||||
|
chosen_format = ds_cfg.get("chosen_format")
|
||||||
|
if not chosen_format:
|
||||||
|
chosen_format = "{" + field_chosen + "}"
|
||||||
|
rejected_format = ds_cfg.get("rejected_format")
|
||||||
|
if not rejected_format:
|
||||||
|
rejected_format = "{" + field_rejected + "}"
|
||||||
|
|
||||||
|
def transform_fn(sample):
|
||||||
|
if (
|
||||||
|
"{" + field_system + "}" in prompt_format
|
||||||
|
and field_system in sample
|
||||||
|
and sample[field_system]
|
||||||
|
):
|
||||||
|
sample["prompt"] = prompt_format.format(
|
||||||
|
system=sample[field_system], prompt=sample[field_prompt]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
||||||
|
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||||
|
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -3,7 +3,7 @@ DPO strategies for zephyr
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
data = {}
|
data = {}
|
||||||
data["prompt"] = (
|
data["prompt"] = (
|
||||||
|
|||||||
@@ -85,12 +85,24 @@ class SFTDataset(BaseModel):
|
|||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedDPOType(BaseModel):
|
||||||
|
"""User defined typing for DPO"""
|
||||||
|
|
||||||
|
field_system: Optional[str] = None
|
||||||
|
field_prompt: Optional[str] = None
|
||||||
|
field_chosen: Optional[str] = None
|
||||||
|
field_rejected: Optional[str] = None
|
||||||
|
prompt_format: Optional[str] = None
|
||||||
|
chosen_format: Optional[str] = None
|
||||||
|
rejected_format: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class DPODataset(BaseModel):
|
class DPODataset(BaseModel):
|
||||||
"""DPO configuration subset"""
|
"""DPO configuration subset"""
|
||||||
|
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
split: Optional[str] = None
|
split: Optional[str] = None
|
||||||
type: Optional[str] = None
|
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -937,7 +937,9 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
for i, data_set in enumerate(split_datasets):
|
for i, data_set in enumerate(split_datasets):
|
||||||
_type = dataset_cfgs[i]["type"]
|
_type = dataset_cfgs[i]["type"]
|
||||||
if _type:
|
if _type:
|
||||||
ds_transform_fn = load_dpo(_type, _cfg)
|
if isinstance(_type, DictDefault):
|
||||||
|
_type = "user_defined.default"
|
||||||
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
split_datasets[i] = data_set.map(
|
split_datasets[i] = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
|||||||
Reference in New Issue
Block a user