From a27d5e1f4e36ae8faa0e60394e6e36dd1ee67ff6 Mon Sep 17 00:00:00 2001 From: George Grigorev Date: Wed, 22 May 2024 13:29:06 +0100 Subject: [PATCH] enable loraplus setting for dpo trainer (#1646) --- src/axolotl/core/trainer_builder.py | 38 ++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) mode change 100644 => 100755 src/axolotl/core/trainer_builder.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py old mode 100644 new mode 100755 index 06b71b3e1..c510c8e10 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -798,6 +798,40 @@ class AxolotlDPOTrainer(DPOTrainer): tag_names = ["axolotl", "dpo"] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = None + + def create_optimizer(self): + if self.args.loraplus_lr_ratio is None: + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + if loraplus_lr_ratio: + print("Using lora+") + loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + optimizer_kwargs, + loraplus_lr_ratio, + loraplus_lr_embedding, + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + @wraps(DPOTrainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ @@ -1483,6 +1517,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True + training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio + training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding training_args_kwargs["lr_scheduler_type"] = ( self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" ) @@ -1535,7 +1571,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha - training_args_cls = TrainingArguments + training_args_cls = AxolotlTrainingArguments if self.cfg.rl == "orpo": training_args_cls = ORPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes