Compare commits

1 commit

optimizer-...grouped_lr

| Author | SHA1 | Date |
|---|---|---|
|  | 1dd7f087b3 |  |

docs/lr_groups.qmd (new file, +29 lines)
@@ -0,0 +1,29 @@

---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---

## Background

Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or group of
modules in a model.

## Example

```yaml
lr_groups:
  - name: o_proj
    modules:
      - self_attn.o_proj.weight
    lr: 1e-6
  - name: q_proj
    modules:
      - model.layers.2.self_attn.q_proj.weight
    lr: 1e-5

learning_rate: 2e-5
```

In this example, we have a default learning rate of 2e-5 across the entire model, but a separate learning rate
of 1e-6 for the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self-attention `q_proj` module.
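As context for the trainer changes further down, here is a minimal, illustrative sketch (not code from this diff) of how an `lr_groups` config like the one above can be mapped onto optimizer parameter groups: each configured module name is substring-matched against `model.named_parameters()`, matched parameters get the group's learning rate, and everything else falls back to the default `learning_rate`. The toy model and helper names below are assumptions made up purely for the demonstration.

```python
# Illustrative sketch only; not the PR's code. Shows lr_groups -> AdamW param groups
# via substring matching of module names against parameter names.
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.o_proj = nn.Linear(8, 8)
        self.q_proj = nn.Linear(8, 8)


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer() for _ in range(3))


class ToyModel(nn.Module):
    # Stand-in for a HF-style model whose parameter names look like
    # "model.layers.N.self_attn.o_proj.weight".
    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()


lr_groups = [
    {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
    {"name": "q_proj", "modules": ["model.layers.2.self_attn.q_proj.weight"], "lr": 1e-5},
]
default_lr = 2e-5

model = ToyModel()
grouped = {group["name"]: [] for group in lr_groups}
default_params = []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # First group whose module pattern is a substring of the parameter name wins.
    match = next((g for g in lr_groups if any(m in name for m in g["modules"])), None)
    if match is not None:
        grouped[match["name"]].append(param)
    else:
        default_params.append(param)

param_groups = [{"params": default_params, "lr": default_lr}] + [
    {"params": grouped[g["name"]], "lr": g["lr"]} for g in lr_groups if grouped[g["name"]]
]
optimizer = torch.optim.AdamW(param_groups)
print([(pg["lr"], len(pg["params"])) for pg in optimizer.param_groups])
```

Because matching is by substring, a short pattern such as `self_attn.o_proj.weight` covers that module in every layer, while a fully qualified name such as `model.layers.2.self_attn.q_proj.weight` pins the group to a single layer.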
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
 
-axolotl-contribs-lgpl==0.0.2
+axolotl-contribs-lgpl==0.0.1b2
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
     "--accelerate/--no-accelerate",
-    default=False,
+    default=True,
     help="Use accelerate launch for multi-GPU inference",
 )
 @click.option(
@@ -124,7 +124,7 @@ def inference(
     if lora_model_dir:
         kwargs["lora_model_dir"] = lora_model_dir
     if base_model:
-        kwargs["base_model"] = base_model
+        kwargs["output_dir"] = base_model
 
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
@@ -68,7 +68,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
-from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.chat_templates import get_chat_template
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -244,6 +244,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups with different LRs."},
+    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -424,11 +428,6 @@ class SchedulerMixin(Trainer):
 
         return self.lr_scheduler
 
-    def _load_optimizer_and_scheduler(self, checkpoint):
-        if not checkpoint and self.args.optimizer_checkpoint is not None:
-            checkpoint = self.args.optimizer_checkpoint
-        return super()._load_optimizer_and_scheduler(checkpoint)
-
 
 class AxolotlTrainer(SchedulerMixin, Trainer):
     """
@@ -467,35 +466,23 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
-    def create_optimizer(self):
-        if (
-            self.args.loraplus_lr_ratio is None
-            and self.args.embedding_lr_scale is None
-            and self.args.embedding_lr is None
-            and self.args.alternate_optimizer
-            not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ]
-        ):
-            return super().create_optimizer()
-
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
         decay_parameters = self.get_decay_parameter_names(opt_model)
         params = {
             "to_weight_decay": {},  # LayerNorm and bias
             "embeddings": {},  # lm_head, embed_tokens,
             "no_weight_decay": {},
         }
-        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-            self.args,
-            opt_model,
-        )
+        lr_groups_lookup = {}
+        lr_groups_learning_rates = {}
+        if self.args.lr_groups:
+            for lr_group in self.args.lr_groups:
+                group_name = lr_group["name"]
+                group_modules = lr_group["modules"]
+                for module in group_modules:
+                    lr_groups_lookup[module] = group_name
+                lr_groups_learning_rates[group_name] = lr_group["lr"]
+                params[f"to_weight_decay_{group_name}"] = {}
 
         for name, param in opt_model.named_parameters():
             if not param.requires_grad:
@@ -505,6 +492,17 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             ):
                 params["embeddings"][name] = param
             elif name in decay_parameters:
+                if lr_groups_lookup and any(
+                    group_modules in name for group_modules in lr_groups_lookup
+                ):
+                    lr_group_module = [
+                        group_modules
+                        for group_modules in lr_groups_lookup
+                        if group_modules in name
+                    ][0]
+                    group_name = lr_groups_lookup[lr_group_module]
+                    params[f"to_weight_decay_{group_name}"][name] = param
+                else:
                     params["to_weight_decay"][name] = param
             else:
                 params["no_weight_decay"][name] = param
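For readers skimming the hunk above: the added branch takes the first configured module pattern that is a substring of the parameter name and routes that parameter into the matching group's weight-decay bucket. A tiny standalone illustration of that lookup (the names here are made up, not from the PR):

```python
# Toy illustration of the substring lookup used in the hunk above (not the PR's code).
lr_groups_lookup = {
    "self_attn.o_proj.weight": "o_proj",
    "model.layers.2.self_attn.q_proj.weight": "q_proj",
}
name = "model.layers.0.self_attn.o_proj.weight"

if any(group_modules in name for group_modules in lr_groups_lookup):
    # Take the first matching pattern; its value is the group name.
    lr_group_module = [
        group_modules for group_modules in lr_groups_lookup if group_modules in name
    ][0]
    print(lr_groups_lookup[lr_group_module])  # -> "o_proj"
```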
@@ -538,6 +536,46 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     "lr": optimizer_kwargs["lr"],
                 }
             )
+        for group_name, group_lr in lr_groups_learning_rates.items():
+            if params[f"to_weight_decay_{group_name}"]:
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(
+                            params[f"to_weight_decay_{group_name}"].values()
+                        ),
+                        "weight_decay": self.args.weight_decay,
+                        "lr": group_lr,
+                    }
+                )
+
+        return optimizer_grouped_parameters
+
+    def create_optimizer(self):
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.embedding_lr_scale is None
+            and self.args.embedding_lr is None
+            and self.args.lr_groups is None
+            and self.args.alternate_optimizer
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "adopt_adamw",
+            ]
+        ):
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
+                opt_model, optimizer_kwargs
+            )
 
             if self.args.loraplus_lr_ratio is not None:
                 loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -554,6 +592,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             elif (
                 self.args.embedding_lr_scale is not None
                 or self.args.embedding_lr is not None
+                or self.args.lr_groups is not None
             ):
                 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                     optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1769,10 +1808,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
-        if self.cfg.optimizer_checkpoint:
-            training_arguments_kwargs[
-                "optimizer_checkpoint"
-            ] = self.cfg.optimizer_checkpoint
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
 
         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1843,8 +1879,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
         if self.cfg.chat_template:
-            training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
-                cfg=self.cfg,
+            training_arguments_kwargs["chat_template"] = get_chat_template(
+                self.cfg.chat_template,
                 tokenizer=self.tokenizer,
             )
 
@@ -1,6 +1,5 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
-import inspect
 import os
 import signal
 import sys
@@ -127,19 +126,6 @@ def train(
     )
 
     if cfg.fix_untrained_tokens:
-        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
-        sig = inspect.signature(fix_untrained_tokens)
-        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
-        if "token_ids_to_fix" in sig.parameters and isinstance(
-            cfg.fix_untrained_tokens, list
-        ):
-            fix_untrained_tokens(
-                model,
-                tokenizer,
-                train_dataset,
-                token_ids_to_fix=cfg.fix_untrained_tokens,
-            )
-        else:
         fix_untrained_tokens(model, tokenizer, train_dataset)
         if cfg.local_rank == 0:
             model.save_pretrained(
@@ -145,6 +145,14 @@ class UserDefinedPrompterType(BaseModel):
     field: Optional[str] = None
 
 
+class LrGroup(BaseModel):
+    """Custom learning rate group configuration"""
+
+    name: str
+    modules: List[str]
+    lr: float
+
+
 class SFTDataset(BaseModel):
     """SFT configuration subset"""
 
@@ -466,6 +474,7 @@ class HyperparametersConfig(BaseModel):
     cosine_min_lr_ratio: Optional[float] = None
     cosine_constant_lr_ratio: Optional[float] = None
     lr_div_factor: Optional[float] = None
+    lr_groups: Optional[List[LrGroup]] = None
 
     adam_epsilon: Optional[float] = None
     adam_beta1: Optional[float] = None
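The hunks above add an `LrGroup` pydantic model and an `lr_groups` field on the hyperparameters config. As a hedged, standalone sketch of how such a schema validates the documented example config (this mirrors the shape of the classes in the diff; it is not the repository's actual module):

```python
# Standalone illustration; the real LrGroup / hyperparameters classes live in
# axolotl's config schema. This simply mirrors their shape to show validation.
from typing import List, Optional

from pydantic import BaseModel


class LrGroup(BaseModel):
    """Custom learning rate group configuration"""

    name: str
    modules: List[str]
    lr: float


class Hyperparameters(BaseModel):
    learning_rate: float
    lr_groups: Optional[List[LrGroup]] = None


cfg = Hyperparameters(
    learning_rate=2e-5,
    lr_groups=[
        {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
        {"name": "q_proj", "modules": ["model.layers.2.self_attn.q_proj.weight"], "lr": 1e-5},
    ],
)
assert cfg.lr_groups is not None and cfg.lr_groups[0].lr == 1e-6
```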
@@ -603,8 +612,6 @@ class AxolotlInputConfig(
     strict: Optional[bool] = Field(default=False)
     resume_from_checkpoint: Optional[str] = None
     auto_resume_from_checkpoints: Optional[bool] = None
-    optimizer_checkpoint: Optional[str] = None
-
     resize_token_embeddings_to_32x: Optional[bool] = None
     mean_resizing_embeddings: Optional[bool] = False
 
@@ -796,7 +803,7 @@ class AxolotlInputConfig(
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None
 
-    fix_untrained_tokens: Optional[Union[int, List[int]]] = None
+    fix_untrained_tokens: Optional[bool] = None
 
     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None