add mistral instruct strategy and fix dpo_loss input
This commit is contained in:
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
DPO strategies for mistral instruct
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_pairs(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a transform for samples with flat ``prompt``/``chosen``/``rejected`` fields.

    Args:
        cfg: Unused here; accepted so every strategy factory shares one call shape.
        **kwargs: Unused here; accepted (and ignored) for signature parity with
            ``argilla_chat`` so callers can forward extra keyword arguments
            uniformly without raising ``TypeError``.

    Returns:
        A ``transform_fn(sample)`` callable. It wraps ``sample["prompt"]`` in
        ``[INST]``/``[/INST]`` markers, converts ``sample["chosen"]`` and
        ``sample["rejected"]`` to strings, and returns the same (mutated)
        ``sample`` mapping.
    """

    def transform_fn(sample):
        # The sample is updated in place and returned, not copied.
        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
        # The f-strings below only stringify the values; no template tokens
        # are appended to the completions by this strategy.
        sample["chosen"] = f"{sample['chosen']}"
        sample["rejected"] = f"{sample['rejected']}"
        return sample

    return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations

    Returns a ``transform_fn(sample)`` callable. It reads ``sample["chosen"]``
    and ``sample["rejected"]`` as sequences of message mappings with a
    ``"content"`` key: element 0 of ``chosen`` supplies the prompt text and
    element 1 of each sequence supplies the completion. ``cfg`` and
    ``kwargs`` are not used.
    """

    def transform_fn(sample):
        # Assignment order (prompt, chosen, rejected) is kept so that a
        # malformed sample fails at the same point with the same partial state.
        chosen_turns = sample["chosen"]
        user_turn = chosen_turns[0]["content"]
        sample["prompt"] = f"[INST] {user_turn} [/INST]"

        chosen_reply = chosen_turns[1]["content"]
        sample["chosen"] = f"{chosen_reply}</s>"

        rejected_reply = sample["rejected"][1]["content"]
        sample["rejected"] = f"{rejected_reply}</s>"

        # The same mapping that was passed in is returned, mutated in place.
        return sample

    return transform_fn
|
||||||
@@ -575,6 +575,7 @@ class AxolotlInputConfig(
|
|||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
orpo_alpha: Optional[float] = None
|
||||||
|
dpo_beta: Optional[float] = None
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||||
|
|||||||
Reference in New Issue
Block a user