diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py index 4364449c5..878718199 100644 --- a/src/axolotl/core/trainer_builder/base.py +++ b/src/axolotl/core/trainer_builder/base.py @@ -326,7 +326,7 @@ class TrainerBuilderBase(abc.ABC): # optim/scheduler training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding - if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: + if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]: training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler else: