jagged lr restart scheudler (#1680) [skip ci]
* jagged lr restart scheudler var name fix make sure to create scheduler first * wire things together * more fixes * fix for nesting scheduler and first anneal phase * no need for relora trainer anymore since we've generalized the relora scheduler * remove redundant relora scheduler and lint * update relora e2e test for updated params * need restart steps for relora test * update quarto docs for dropped relora trainer * update example yaml * drop verbose arg * min lr scale support for jagged lr * don't let min_lr be nonetype * cleanup args
This commit is contained in:
@@ -59,7 +59,6 @@ quartodoc:
|
|||||||
- core.trainers.base
|
- core.trainers.base
|
||||||
- core.trainers.trl
|
- core.trainers.trl
|
||||||
- core.trainers.mamba
|
- core.trainers.mamba
|
||||||
- core.trainers.relora
|
|
||||||
- core.trainers.dpo.trainer
|
- core.trainers.dpo.trainer
|
||||||
- core.trainers.grpo.trainer
|
- core.trainers.grpo.trainer
|
||||||
- core.trainers.grpo.sampler
|
- core.trainers.grpo.sampler
|
||||||
|
|||||||
@@ -25,9 +25,12 @@ lora_alpha: 16
|
|||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
relora_steps: 150
|
relora: true
|
||||||
relora_warmup_ratio: 0.1
|
relora_prune_ratio: 0.9
|
||||||
relora_cpu_offload: false
|
relora_cpu_offload: false
|
||||||
|
jagged_restart_steps: 150
|
||||||
|
jagged_restart_warmup_steps: 10
|
||||||
|
jagged_restart_anneal_steps: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
|
|||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
AxolotlRewardTrainer,
|
AxolotlRewardTrainer,
|
||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora:
|
||||||
callbacks.append(ReLoRACallback(self.cfg))
|
callbacks.append(ReLoRACallback(self.cfg))
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
if trainer_cls:
|
if trainer_cls:
|
||||||
return trainer_cls
|
return trainer_cls
|
||||||
if self.cfg.relora_steps:
|
|
||||||
return ReLoRATrainer
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.sample_packing_eff_est
|
self.cfg.sample_packing_eff_est
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
|
||||||
training_arguments_kwargs["relora_warmup_steps"] = (
|
|
||||||
self.cfg.relora_warmup_steps
|
|
||||||
)
|
|
||||||
if self.cfg.relora_anneal_steps:
|
|
||||||
training_arguments_kwargs["relora_anneal_steps"] = (
|
|
||||||
self.cfg.relora_anneal_steps
|
|
||||||
)
|
|
||||||
if self.cfg.relora_prune_ratio:
|
if self.cfg.relora_prune_ratio:
|
||||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||||
self.cfg.relora_prune_ratio
|
self.cfg.relora_prune_ratio
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.jagged_restart_steps:
|
||||||
|
training_arguments_kwargs["jagged_restart_steps"] = (
|
||||||
|
self.cfg.jagged_restart_steps
|
||||||
|
)
|
||||||
|
if self.cfg.jagged_restart_warmup_steps:
|
||||||
|
training_arguments_kwargs["jagged_restart_warmup_steps"] = (
|
||||||
|
self.cfg.jagged_restart_warmup_steps
|
||||||
|
)
|
||||||
|
if self.cfg.jagged_restart_anneal_steps:
|
||||||
|
training_arguments_kwargs["jagged_restart_anneal_steps"] = (
|
||||||
|
self.cfg.jagged_restart_anneal_steps
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||||
training_arguments_kwargs["lisa_step_interval"] = (
|
training_arguments_kwargs["lisa_step_interval"] = (
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
|
|||||||
from .dpo.trainer import AxolotlDPOTrainer
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||||
from .mamba import AxolotlMambaTrainer
|
from .mamba import AxolotlMambaTrainer
|
||||||
from .relora import ReLoRATrainer
|
|
||||||
from .trl import (
|
from .trl import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
|
JaggedLRRestartScheduler,
|
||||||
RexLR,
|
RexLR,
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if self.args.jagged_restart_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.jagged_restart_warmup_steps or 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.jagged_restart_anneal_steps or 1
|
||||||
|
)
|
||||||
|
if not self.lr_scheduler:
|
||||||
|
super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
self.lr_scheduler,
|
||||||
|
self.args.jagged_restart_steps,
|
||||||
|
warmup_steps,
|
||||||
|
anneal_steps,
|
||||||
|
min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
|
||||||
|
)
|
||||||
|
|
||||||
return self.lr_scheduler # type: ignore
|
return self.lr_scheduler # type: ignore
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Module for ReLoRA trainer"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: torch.optim.Optimizer | None = None,
|
|
||||||
) -> LRScheduler:
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler: LRScheduler = super().create_scheduler(
|
|
||||||
num_training_steps, optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler( # type: ignore
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler # type: ignore
|
|
||||||
|
|
||||||
return self.lr_scheduler # type: ignore
|
|
||||||
@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
)
|
)
|
||||||
relora_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_prune_ratio: Optional[float] = field(
|
relora_prune_ratio: Optional[float] = field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
)
|
)
|
||||||
|
jagged_restart_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for jagged restarts"},
|
||||||
|
)
|
||||||
|
jagged_restart_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "how many warmup steps to take after reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jagged_restart_anneal_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "how many anneal steps to take before reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os.path
|
|||||||
import shutil
|
import shutil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Sequence, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import peft
|
import peft
|
||||||
@@ -14,8 +14,6 @@ import safetensors.torch as st
|
|||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from torch.distributed.optim import ZeroRedundancyOptimizer
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.optim.optimizer import Optimizer
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainerControl,
|
TrainerControl,
|
||||||
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
||||||
|
|
||||||
def __init__(self, cfg: DictDefault):
|
def __init__(self, cfg: DictDefault):
|
||||||
self.relora_steps = cfg.relora_steps
|
self.relora_steps = cfg.jagged_restart_steps
|
||||||
self.cpu_offload = cfg.relora_cpu_offload
|
self.cpu_offload = cfg.relora_cpu_offload
|
||||||
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
||||||
self.last_full_model = cfg.base_model
|
self.last_full_model = cfg.base_model
|
||||||
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAScheduler(LRScheduler):
|
|
||||||
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
optimizer: Optimizer,
|
|
||||||
inner_schedule: LRScheduler,
|
|
||||||
relora_steps: int,
|
|
||||||
warmup_steps: int,
|
|
||||||
anneal_steps: int = 1,
|
|
||||||
min_lr_scale: float = 0.001,
|
|
||||||
) -> None:
|
|
||||||
self.inner_schedule = inner_schedule
|
|
||||||
self.relora_steps = relora_steps
|
|
||||||
self.warmup_steps = warmup_steps
|
|
||||||
self.anneal_steps = anneal_steps
|
|
||||||
self.min_lr_scale = min_lr_scale
|
|
||||||
super().__init__(optimizer, inner_schedule.last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self) -> float:
|
|
||||||
self.inner_schedule.last_epoch = self.last_epoch
|
|
||||||
|
|
||||||
original = self.inner_schedule.get_lr()
|
|
||||||
step = self.last_epoch
|
|
||||||
|
|
||||||
if step < self.relora_steps - self.warmup_steps:
|
|
||||||
scale = 1
|
|
||||||
else:
|
|
||||||
per_relora_progress = step % self.relora_steps
|
|
||||||
if per_relora_progress < self.warmup_steps:
|
|
||||||
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
|
|
||||||
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
|
|
||||||
cycle_t = min(
|
|
||||||
1.0,
|
|
||||||
(self.relora_steps - per_relora_progress) / self.anneal_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cycle_t = 1
|
|
||||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
|
||||||
|
|
||||||
if isinstance(original, Sequence):
|
|
||||||
return [lr * scale for lr in original]
|
|
||||||
return original * scale
|
|
||||||
|
|
||||||
|
|
||||||
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
||||||
model_name = "model.safetensors"
|
model_name = "model.safetensors"
|
||||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ def save_trained_model(
|
|||||||
"your model weights with `axolotl quantize`."
|
"your model weights with `axolotl quantize`."
|
||||||
)
|
)
|
||||||
# Handle ReLoRA early return case
|
# Handle ReLoRA early return case
|
||||||
if cfg.relora_steps:
|
if cfg.relora:
|
||||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
|
|||||||
num_cycles=num_cycles,
|
num_cycles=num_cycles,
|
||||||
)
|
)
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
class JaggedLRRestartScheduler(LRScheduler):
|
||||||
|
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
inner_schedule: LRScheduler,
|
||||||
|
jagged_restart_steps: int,
|
||||||
|
jagged_restart_warmup_steps: int,
|
||||||
|
jagged_restart_anneal_steps: int = 1,
|
||||||
|
min_lr_scale: float = 0.001,
|
||||||
|
) -> None:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.inner_schedule = inner_schedule
|
||||||
|
self.restarts_steps = jagged_restart_steps
|
||||||
|
self.warmup_steps = jagged_restart_warmup_steps
|
||||||
|
self.anneal_steps = jagged_restart_anneal_steps
|
||||||
|
self.min_lr_scale = min_lr_scale
|
||||||
|
super().__init__(optimizer, inner_schedule.last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self) -> float | Sequence[float]:
|
||||||
|
self.inner_schedule.last_epoch = self.last_epoch
|
||||||
|
|
||||||
|
original = self.inner_schedule.get_lr()
|
||||||
|
step = self.last_epoch
|
||||||
|
|
||||||
|
if step < self.restarts_steps - self.anneal_steps:
|
||||||
|
scale = 1
|
||||||
|
else:
|
||||||
|
per_restart_progress = step % self.restarts_steps
|
||||||
|
if per_restart_progress < self.warmup_steps:
|
||||||
|
cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
|
||||||
|
elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
|
||||||
|
cycle_t = min(
|
||||||
|
1.0,
|
||||||
|
(self.restarts_steps - per_restart_progress) / self.anneal_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cycle_t = 1
|
||||||
|
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||||
|
|
||||||
|
if isinstance(original, Sequence):
|
||||||
|
return [lr * scale for lr in original]
|
||||||
|
|
||||||
|
return original * scale
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
|
|||||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||||
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
|
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
|
||||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
from axolotl.utils.schemas.validation import ValidationMixin
|
from axolotl.utils.schemas.validation import ValidationMixin
|
||||||
from axolotl.utils.schemas.vllm import VllmConfig
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
|
|||||||
ModelOutputConfig,
|
ModelOutputConfig,
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
ReLoRAConfig,
|
ReLoRAConfig,
|
||||||
|
JaggedLRConfig,
|
||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
|
|||||||
@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
|
|||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|
||||||
relora_steps: int | None = Field(
|
relora: bool | None = Field(
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Number of steps per ReLoRA restart"},
|
|
||||||
)
|
|
||||||
relora_warmup_steps: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Number of per-restart warmup steps"},
|
|
||||||
)
|
|
||||||
relora_anneal_steps: int | None = Field(
|
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of anneal steps for each relora cycle"
|
"description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
relora_prune_ratio: float | None = Field(
|
relora_prune_ratio: float | None = Field(
|
||||||
|
|||||||
@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
|
|||||||
if learning_rate and isinstance(learning_rate, str):
|
if learning_rate and isinstance(learning_rate, str):
|
||||||
learning_rate = float(learning_rate)
|
learning_rate = float(learning_rate)
|
||||||
return learning_rate
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
|
class JaggedLRConfig(BaseModel):
|
||||||
|
"""JaggedLR configuration subset, can be used w/ ReLoRA training"""
|
||||||
|
|
||||||
|
jagged_restart_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "how often to reset for jagged restarts"},
|
||||||
|
)
|
||||||
|
jagged_restart_warmup_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "how many warmup steps to take after reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jagged_restart_anneal_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "how many anneal steps to take before reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -1164,7 +1164,9 @@ class ComplexValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_relora(self):
|
def check_relora(self):
|
||||||
if self.relora_steps:
|
if self.relora:
|
||||||
|
if not self.jagged_restart_steps:
|
||||||
|
raise ValueError("jagged_restart_steps must be set to use ReLoRA")
|
||||||
if self.adapter not in ("lora", "qlora"):
|
if self.adapter not in ("lora", "qlora"):
|
||||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||||
|
|
||||||
|
|||||||
@@ -34,9 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_modules": ["q_proj", "v_proj"],
|
"lora_target_modules": ["q_proj", "v_proj"],
|
||||||
"relora_steps": 50,
|
"relora": True,
|
||||||
"relora_warmup_steps": 10,
|
"jagged_restart_steps": 50,
|
||||||
"relora_anneal_steps": 10,
|
"jagged_restart_warmup_steps": 10,
|
||||||
|
"jagged_restart_anneal_steps": 10,
|
||||||
"relora_prune_ratio": 0.9,
|
"relora_prune_ratio": 0.9,
|
||||||
"relora_cpu_offload": True,
|
"relora_cpu_offload": True,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
|
|||||||
Reference in New Issue
Block a user