Compare commits
1 Commits
jagged-res
...
wait-distr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
459f407e69 |
@@ -8,7 +8,6 @@ from transformers.trainer import Trainer
|
|||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
JaggedLRRestartScheduler,
|
|
||||||
RexLR,
|
RexLR,
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
@@ -113,22 +112,7 @@ class SchedulerMixin(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
super().create_scheduler(num_training_steps, optimizer=optimizer)
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
if self.args.jagged_restart_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.jagged_restarts_warmup_steps or 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.jagged_restarts_anneal_steps or 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
self.lr_scheduler,
|
|
||||||
self.args.jagged_restart_steps,
|
|
||||||
warmup_steps,
|
|
||||||
anneal_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|||||||
@@ -86,22 +86,6 @@ class AxolotlTrainingMixins:
|
|||||||
default=0.9,
|
default=0.9,
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
)
|
)
|
||||||
jagged_restart_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for jagged restarts"},
|
|
||||||
)
|
|
||||||
jagged_restarts_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "how many warmup steps to take after reset for jagged restarts"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
jagged_restarts_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "how many anneal steps to take before reset for jagged restarts"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -289,16 +289,18 @@ def save_trained_model(
|
|||||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
elif cfg.local_rank == 0:
|
else:
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.local_rank == 0:
|
||||||
model = BetterTransformer.reverse(model)
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
trainer.model.save_pretrained(
|
trainer.model.save_pretrained(
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
)
|
)
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
# TODO: add integration support so this can be implemented completely within the plugin
|
# TODO: add integration support so this can be implemented completely within the plugin
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
@@ -293,47 +292,3 @@ def get_cosine_schedule_with_warmup_decay_constant(
|
|||||||
num_cycles=num_cycles,
|
num_cycles=num_cycles,
|
||||||
)
|
)
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
class JaggedLRRestartScheduler(LRScheduler):
|
|
||||||
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
optimizer: Optimizer,
|
|
||||||
inner_schedule: LRScheduler,
|
|
||||||
jagged_restarts_steps: int,
|
|
||||||
jagged_restarts_warmup_steps: int,
|
|
||||||
jagged_restarts_anneal_steps: int = 1,
|
|
||||||
min_lr_scale: float = 0.001,
|
|
||||||
) -> None:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
self.inner_schedule = inner_schedule
|
|
||||||
self.restarts_steps = jagged_restarts_steps
|
|
||||||
self.warmup_steps = jagged_restarts_warmup_steps
|
|
||||||
self.anneal_steps = jagged_restarts_anneal_steps
|
|
||||||
self.min_lr_scale = min_lr_scale
|
|
||||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
|
||||||
|
|
||||||
def get_lr(self) -> List[float]:
|
|
||||||
self.inner_schedule.last_epoch = self.last_epoch
|
|
||||||
|
|
||||||
original: List[float] = self.inner_schedule.get_lr()
|
|
||||||
step = self.last_epoch
|
|
||||||
|
|
||||||
if step < self.restarts_steps:
|
|
||||||
scale = 1
|
|
||||||
else:
|
|
||||||
per_restart_progress = step % self.restarts_steps
|
|
||||||
if per_restart_progress < self.warmup_steps:
|
|
||||||
cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
|
|
||||||
elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
|
|
||||||
cycle_t = min(
|
|
||||||
1.0,
|
|
||||||
(self.restarts_steps - per_restart_progress) / self.anneal_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cycle_t = 1
|
|
||||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
|
||||||
|
|
||||||
return original * scale
|
|
||||||
|
|||||||
Reference in New Issue
Block a user