fix learning rate scheduler's warnings (#1135) [skip ci]
* fix schedulers warnings
* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
committed by
GitHub
parent
98b4762077
commit
b4ac96adef
@@ -170,24 +170,30 @@ class AxolotlTrainer(Trainer):
|
|||||||
num_training_steps (int): The number of training steps to do.
|
num_training_steps (int): The number of training steps to do.
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
"""
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if (
|
if use_cosine_quadratic:
|
||||||
self.args.lr_scheduler_type == "cosine"
|
if use_cosine_min_lr:
|
||||||
and self.args.lr_quadratic_warmup is True
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
):
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
optimizer,
|
optimizer,
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
num_training_steps=num_training_steps,
|
num_training_steps=num_training_steps,
|
||||||
)
|
)
|
||||||
elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
if self.args.deepspeed:
|
|
||||||
LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
|
|
||||||
in the deepspeed JSON")
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
optimizer,
|
optimizer,
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
@@ -196,6 +202,13 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
|||||||
Reference in New Issue
Block a user