diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8c13eb78d..15ad71470 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -170,24 +170,30 @@ class AxolotlTrainer(Trainer):
             num_training_steps (int): The number of training steps to do.
             optimizer (torch.optim.Optimizer): The training optimizer
         """
+        use_cosine_quadratic = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.lr_quadratic_warmup is True
+        )
+
+        use_cosine_min_lr = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.cosine_min_lr_ratio is not None
+        )

         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if (
-                self.args.lr_scheduler_type == "cosine"
-                and self.args.lr_quadratic_warmup is True
-            ):
+            if use_cosine_quadratic:
+                if use_cosine_min_lr:
+                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
+
                 self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                     num_training_steps=num_training_steps,
                 )
-            elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
+            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                 assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
-                if self.args.deepspeed:
-                    LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
-                        in the deepspeed JSON")
                 self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -196,6 +202,13 @@ class AxolotlTrainer(Trainer):
                 )
             else:
                 return super().create_scheduler(num_training_steps, optimizer)
+        else:
+            if use_cosine_quadratic:
+                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
+
+            if use_cosine_min_lr:
+                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
+
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
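
Note: the patch calls `get_cosine_schedule_with_min_lr` but does not show its body. As a rough mental model only, a schedule like this is typically built on `torch.optim.lr_scheduler.LambdaLR`: linear warmup, then a cosine curve that bottoms out at `min_lr_ratio * base_lr` instead of zero. The sketch below is an assumption for illustration (the function name `cosine_schedule_with_min_lr_sketch` is hypothetical), not axolotl's actual implementation, which may differ in detail.

```python
import math

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def cosine_schedule_with_min_lr_sketch(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float,
) -> LambdaLR:
    """Linear warmup, then cosine decay that floors at min_lr_ratio * base_lr."""
    # Mirrors the assertion in the patch: the ratio is a fraction of the base lr.
    assert 0 <= min_lr_ratio <= 1.0

    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            # Linear warmup from 0 up to the base learning rate.
            return current_step / max(1, num_warmup_steps)
        # Cosine decay from 1.0 down to min_lr_ratio rather than to 0.
        progress = (current_step - num_warmup_steps) / max(
            1, num_training_steps - num_warmup_steps
        )
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    # LambdaLR multiplies each param group's base lr by lr_lambda(step).
    return LambdaLR(optimizer, lr_lambda)
```

This also clarifies the new `else` branch: when `self.lr_scheduler` is already set (for example by a scheduler configured in the deepspeed JSON), none of these custom schedules are constructed, so the patch logs a warning instead of silently ignoring `lr_quadratic_warmup` or `cosine_min_lr_ratio`.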