add mistral instruct strategy and fix dpo_loss input
This commit is contained in:
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
DPO strategies for mistral instruct
|
||||
"""
|
||||
|
||||
|
||||
def prompt_pairs(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a sample transform for plain prompt/chosen/rejected DPO pairs.

    Args:
        cfg: configuration object; not read by this strategy.
        **kwargs: extra loader options; accepted (and ignored) so this factory
            shares the ``(cfg, **kwargs)`` call signature of ``argilla_chat``.

    Returns:
        A ``transform_fn(sample)`` callable that mutates ``sample`` in place and
        returns it.
    """

    def transform_fn(sample):
        # Wrap the raw prompt in the instruct tags; no padding spaces are
        # inserted between the tags and the prompt text.
        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
        # The responses are only passed through string formatting; nothing is
        # appended to them here.
        sample["chosen"] = f"{sample['chosen']}"
        sample["rejected"] = f"{sample['rejected']}"
        return sample

    return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a sample transform for argilla/dpo-mix-7k conversations.

    Args:
        cfg: configuration object; not read by this strategy.
        **kwargs: extra loader options; accepted and ignored.

    Returns:
        A ``transform_fn(sample)`` callable that mutates ``sample`` in place and
        returns it. ``sample["chosen"]`` and ``sample["rejected"]`` must be
        indexable sequences whose items expose a ``"content"`` key.
    """

    def transform_fn(sample):
        # The prompt is taken from the first message of the chosen conversation
        # (presumably the user turn -- verify against the dataset schema).
        chosen_turns = sample["chosen"]
        opening_message = chosen_turns[0]["content"]
        sample["prompt"] = f"[INST] {opening_message} [/INST]"

        # The second message of each conversation becomes the response text,
        # with a literal "</s>" appended.
        chosen_reply = chosen_turns[1]["content"]
        sample["chosen"] = f"{chosen_reply}</s>"

        rejected_reply = sample["rejected"][1]["content"]
        sample["rejected"] = f"{rejected_reply}</s>"

        return sample

    return transform_fn
|
||||
@@ -575,6 +575,7 @@ class AxolotlInputConfig(
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
orpo_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
|
||||
Reference in New Issue
Block a user