add mistral instruct strategy and fix dpo_loss input

This commit is contained in:
Wing Lian
2024-05-02 23:02:03 -04:00
parent f58fcd09ec
commit 0554105baa
2 changed files with 31 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
sample["chosen"] = f"{sample['chosen']}"
sample["rejected"] = f"{sample['rejected']}"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
return sample
return transform_fn

View File

@@ -575,6 +575,7 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]