fix learning rate scheduler's warnings (#1135) [skip ci]

* fix schedulers warnings

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Ricardo Dominguez-Olmedo
2024-01-25 13:09:34 +01:00
committed by GitHub
parent 98b4762077
commit b4ac96adef

View File

@@ -170,24 +170,30 @@ class AxolotlTrainer(Trainer):
num_training_steps (int): The number of training steps to do. num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer optimizer (torch.optim.Optimizer): The training optimizer
""" """
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on # fmt: on
if ( if use_cosine_quadratic:
self.args.lr_scheduler_type == "cosine" if use_cosine_min_lr:
and self.args.lr_quadratic_warmup is True LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
) )
elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None: elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
if self.args.deepspeed:
LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
in the deepspeed JSON")
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -196,6 +202,13 @@ class AxolotlTrainer(Trainer):
) )
else: else:
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: