Compare commits


8 commits (3181...sppo)

Author     SHA1        Message                                                Date
Wing Lian  6a9ac4ad27  consistency w sppo -> sppo_hard                        2024-05-06 16:58:58 -04:00
Wing Lian  027f7d54f0  update for sppo                                        2024-05-06 16:55:46 -04:00
Wing Lian  0554105baa  add mistral instruct strategy and fix dpo_loss input   2024-05-06 16:55:18 -04:00
Wing Lian  f58fcd09ec  use DPOConfig                                          2024-05-06 16:55:16 -04:00
Wing Lian  60fecac367  bump trl                                               2024-05-06 16:54:03 -04:00
Wing Lian  b301068098  remove override                                        2024-05-06 16:54:02 -04:00
Wing Lian  df645906eb  invert check                                           2024-05-06 16:54:02 -04:00
Wing Lian  7fea5822f0  add support for SPPO                                   2024-05-06 16:54:02 -04:00
7 changed files with 47 additions and 6 deletions

View File

@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl

-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
 rl:

 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
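
For example, a minimal config enabling the new loss (the dataset path and type below are placeholders for illustration, not taken from this diff):

rl: sppo_hard
dpo_beta: 0.1  # optional; the trainer builder falls back to 0.1 when unset
datasets:
  - path: argilla/dpo-mix-7k  # placeholder preference dataset
    type: chatml.argilla_chat  # assumed strategy name; verify against your axolotl version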

View File

@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs

-trl==0.8.5
+trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
 zstandard==0.22.0
 fastcore

View File

@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length

 from axolotl.loraplus import create_loraplus_optimizer
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
+            training_args_cls = DPOConfig
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1552,6 +1555,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"
elif self.cfg.rl == "sppo_hard":
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
@@ -1560,7 +1565,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
             trainer_cls = AxolotlDPOTrainer
             dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
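
Axolotl only forwards loss_type="sppo_hard" to TRL's DPOTrainer; the loss itself lives in TRL at the pinned commit. Conceptually it is a squared-margin penalty on the implicit rewards, roughly as follows (a sketch of the idea, not TRL's verbatim code):

import torch

def sppo_hard_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # implicit reward: log-ratio of policy to reference for each completion
    chosen_rewards = policy_chosen_logps - reference_chosen_logps
    rejected_rewards = policy_rejected_logps - reference_rejected_logps
    # "hard" SPPO targets: drive the chosen reward toward +1/(2*beta)
    # and the rejected reward toward -1/(2*beta)
    return (chosen_rewards - 0.5 / beta) ** 2 + (rejected_rewards + 0.5 / beta) ** 2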

View File

@@ -0,0 +1,30 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
sample["chosen"] = f"{sample['chosen']}"
sample["rejected"] = f"{sample['rejected']}"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
return sample
return transform_fn
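
A quick smoke test of the new argilla_chat transform on a dpo-mix-7k style record (sample values invented for illustration):

sample = {
    "chosen": [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7"},
    ],
    "rejected": [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "8"},
    ],
}
transform = argilla_chat(cfg=None)
result = transform(sample)
# result["prompt"]   == "[INST] Name a prime number. [/INST]"
# result["chosen"]   == "7</s>"
# result["rejected"] == "8</s>"

Note the ordering inside transform_fn matters: the prompt is read from chosen[0] before the chosen field is overwritten with the assistant reply.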

View File

@@ -133,6 +133,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -574,6 +575,7 @@ class AxolotlInputConfig(
     neftune_noise_alpha: Optional[float] = None

     orpo_alpha: Optional[float] = None
+    dpo_beta: Optional[float] = None

     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
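
To see how the new enum member and field behave under validation, here is a standalone pydantic sketch (a minimal stand-in, not axolotl's full AxolotlInputConfig):

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class RLType(str, Enum):
    dpo = "dpo"
    ipo = "ipo"
    kto_pair = "kto_pair"
    orpo = "orpo"
    sppo_hard = "sppo_hard"


class MiniConfig(BaseModel):
    rl: Optional[RLType] = None
    dpo_beta: Optional[float] = None


cfg = MiniConfig(rl="sppo_hard", dpo_beta=0.1)
assert cfg.rl is RLType.sppo_hard and cfg.dpo_beta == 0.1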

View File

@@ -789,7 +789,11 @@ def load_model(
     if not reference_model or cfg.lora_model_dir:
         # if we're not loading the reference model, then we're loading the model for training
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
+            and not cfg.merge_lora
+        ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
         else:
             model, lora_config = load_adapter(model, cfg, cfg.adapter)

View File

@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]