Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273)
@@ -50,6 +50,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
+    get_cosine_schedule_with_warmup_decay_constant,
 )

 try:
@@ -164,6 +165,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
     )
+    cosine_constant_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
+        },
+    )


 class AxolotlTrainer(Trainer):
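The two help strings pin down the semantics: the floor of the schedule is min_lr_ratio * learning_rate, and the constant phase begins at cosine_constant_lr_ratio * max_steps. A quick arithmetic sketch with hypothetical values (not part of the commit):

    # Hypothetical values, chosen only to illustrate the two ratios.
    learning_rate = 3e-5
    max_steps = 10_000
    cosine_min_lr_ratio = 0.1        # floor is 10% of the peak LR
    cosine_constant_lr_ratio = 0.75  # cosine decay ends 75% of the way in

    constant_start_step = int(max_steps * cosine_constant_lr_ratio)  # 7500
    min_lr = cosine_min_lr_ratio * learning_rate                     # 3e-06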
@@ -221,6 +228,16 @@ class AxolotlTrainer(Trainer):
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
             )
+        elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
+            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+            assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
+            self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+                min_lr_ratio=self.args.cosine_min_lr_ratio,
+                constant_lr_ratio=self.args.cosine_constant_lr_ratio,
+            )
         elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
             assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
             self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
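Note the branch order: when both ratios are set (and use_cosine_min_lr, computed elsewhere in the trainer, is true), the new warmup/decay/constant schedule takes precedence over the plain min-LR cosine below it. A hypothetical standalone mirror of the selection logic, for skimming:

    # Hypothetical helper (not in the commit) mirroring the elif chain above.
    def pick_cosine_schedule(min_lr_ratio, constant_lr_ratio, use_cosine_min_lr):
        if min_lr_ratio and constant_lr_ratio and use_cosine_min_lr:
            return "cosine_with_warmup_decay_constant"
        if min_lr_ratio and use_cosine_min_lr:
            return "cosine_with_min_lr"
        return "default"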
@@ -887,6 +904,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
+        training_arguments_kwargs[
+            "cosine_constant_lr_ratio"
+        ] = self.cfg.cosine_constant_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
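Both kwargs are read straight off the user config, so the feature is driven by two new config keys. A hedged stand-in for the relevant entries (axolotl configs are YAML; shown here as the equivalent Python mapping, with the non-cosine keys as assumed context):

    # Only the two cosine_* keys are introduced by this commit; the rest is
    # assumed surrounding config.
    cfg = {
        "learning_rate": 3e-5,
        "lr_scheduler": "cosine",
        "warmup_steps": 100,
        "cosine_min_lr_ratio": 0.1,
        "cosine_constant_lr_ratio": 0.75,
    }

The remaining hunks land in the schedulers module itself (axolotl.utils.schedulers, per the import hunk at the top): two trailing-comma cleanups to the existing lambda signatures, then the new schedule.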
@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_cycles: float
+    num_cycles: float,
 ):
     if current_step < num_warmup_steps:
         return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    min_lr_ratio: float
+    min_lr_ratio: float,
 ):
     # Warm up
     if current_step < num_warmup_steps:
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
         min_lr_ratio=min_lr_ratio,
     )
     return LambdaLR(optimizer, lr_lambda)
+
+
+def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    constant_lr_ratio: float,
+    min_lr_ratio: float,
+    num_cycles: float,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    num_constant_steps = int(num_training_steps * constant_lr_ratio)
+    current_step = min(current_step, num_constant_steps)
+
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_constant_steps - num_warmup_steps)
+    )
+
+    return (
+        max(
+            0,
+            (1 - min_lr_ratio)
+            * 0.5
+            * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+        )
+        + min_lr_ratio
+    )
+
+
+def get_cosine_schedule_with_warmup_decay_constant(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    constant_lr_ratio: float,
+    min_lr_ratio: float,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
"""
|
||||
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
|
||||
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
num_warmup_steps (`int`):
|
||||
The number of steps for the warmup phase.
|
||||
num_training_steps (`int`):
|
||||
The total number of training steps.
|
||||
constant_lr_ratio: (`float`):
|
||||
The ratio of num_training_steps to decrease by cosine function.
|
||||
min_lr_ratio: (`float):
|
||||
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
|
||||
num_cycles (`float`, *optional*, defaults to 0.5):
|
||||
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||
following a half-cosine).
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
|
||||
+    lr_lambda = partial(
+        _get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        constant_lr_ratio=constant_lr_ratio,
+        min_lr_ratio=min_lr_ratio,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
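To see all three phases end to end, here is a minimal sketch (not part of the commit) that sweeps the new scheduler with a dummy optimizer; the model, peak LR, and step counts are illustrative assumptions:

    import torch
    from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant

    model = torch.nn.Linear(4, 4)  # stand-in parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    scheduler = get_cosine_schedule_with_warmup_decay_constant(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=1_000,
        constant_lr_ratio=0.75,
        min_lr_ratio=0.1,
    )

    for step in range(1_000):
        optimizer.step()
        scheduler.step()
        if step in (99, 400, 749, 999):
            print(step, scheduler.get_last_lr()[0])

    # Expected shape: roughly 3e-5 at the end of warmup (step 99), partway
    # through the cosine decay at step 400, then pinned at 3e-6
    # (= 0.1 * 3e-5) from step 749 onward.

Because the lambda clamps current_step to num_constant_steps, the same function serves both the decay and the constant tail: once progress reaches 1.0 the cosine term is 0 and only min_lr_ratio remains.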