add support for simpo via cpo trainer (#1772)
* add support for simpo via cpo trainer * add cpo_alpha / sft_weight from the paper * make sure to use the right builder for simpo
This commit is contained in:
@@ -30,7 +30,16 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
from trl import (
|
||||||
|
CPOConfig,
|
||||||
|
CPOTrainer,
|
||||||
|
DPOConfig,
|
||||||
|
DPOTrainer,
|
||||||
|
KTOConfig,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOConfig,
|
||||||
|
ORPOTrainer,
|
||||||
|
)
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -265,6 +274,18 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||||
|
"""
|
||||||
|
CPO config for CPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
simpo_gamma: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "simpo gamma parameter"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
@@ -985,6 +1006,14 @@ class AxolotlKTOTrainer(KTOTrainer):
|
|||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -1707,6 +1736,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
if self.cfg.rl_beta:
|
||||||
|
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
@@ -1715,9 +1746,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
if self.cfg.rl == "simpo":
|
||||||
|
training_args_cls = AxolotlCPOConfig
|
||||||
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
|
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||||
|
if self.cfg.cpo_alpha is not None:
|
||||||
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
@@ -1725,7 +1763,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "kto":
|
if self.cfg.rl == "kto":
|
||||||
training_args_cls = AxolotlKTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
|
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
|
||||||
training_args_kwargs["desirable_weight"] = (
|
training_args_kwargs["desirable_weight"] = (
|
||||||
self.cfg.kto_desirable_weight or 1.0
|
self.cfg.kto_desirable_weight or 1.0
|
||||||
)
|
)
|
||||||
@@ -1771,7 +1808,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
# these aren't used for the ORPO trainer
|
||||||
@@ -1785,6 +1821,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl in ["kto"]:
|
elif self.cfg.rl in ["kto"]:
|
||||||
trainer_cls = AxolotlKTOTrainer
|
trainer_cls = AxolotlKTOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
elif self.cfg.rl in ["simpo"]:
|
||||||
|
trainer_cls = AxolotlCPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
simpo = "simpo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
@@ -644,6 +645,8 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
orpo_alpha: Optional[float] = None
|
||||||
rpo_alpha: Optional[float] = None
|
rpo_alpha: Optional[float] = None
|
||||||
|
simpo_gamma: Optional[float] = None
|
||||||
|
cpo_alpha: Optional[float] = None
|
||||||
|
|
||||||
kto_desirable_weight: Optional[float] = None
|
kto_desirable_weight: Optional[float] = None
|
||||||
kto_undesirable_weight: Optional[float] = None
|
kto_undesirable_weight: Optional[float] = None
|
||||||
|
|||||||
@@ -425,7 +425,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
Reference in New Issue
Block a user