Compare commits
1 Commits
main
...
grouped_lr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd7f087b3 |
29
docs/lr_groups.qmd
Normal file
29
docs/lr_groups.qmd
Normal file
@@ -0,0 +1,29 @@
|
||||
---
|
||||
title: Learning Rate Groups
|
||||
description: "Setting different learning rates by module name"
|
||||
---
|
||||
|
||||
## Background
|
||||
|
||||
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
|
||||
modules in a model.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
lr_groups:
|
||||
- name: o_proj
|
||||
modules:
|
||||
- self_attn.o_proj.weight
|
||||
lr: 1e-6
|
||||
- name: q_proj
|
||||
modules:
|
||||
- model.layers.2.self_attn.q_proj.weight
|
||||
lr: 1e-5
|
||||
|
||||
learning_rate: 2e-5
|
||||
```
|
||||
|
||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
|
||||
self attention `q_proj` module.
|
||||
@@ -244,6 +244,10 @@ class AxolotlTrainingMixins:
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
@@ -462,11 +466,96 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
if lr_groups_lookup and any(
|
||||
group_modules in name for group_modules in lr_groups_lookup
|
||||
):
|
||||
lr_group_module = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
][0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
@@ -480,59 +569,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
@@ -549,6 +592,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
or self.args.lr_groups is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
@@ -1764,6 +1808,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
|
||||
@@ -145,6 +145,14 @@ class UserDefinedPrompterType(BaseModel):
|
||||
field: Optional[str] = None
|
||||
|
||||
|
||||
class LrGroup(BaseModel):
|
||||
"""Custom learning rate group configuration"""
|
||||
|
||||
name: str
|
||||
modules: List[str]
|
||||
lr: float
|
||||
|
||||
|
||||
class SFTDataset(BaseModel):
|
||||
"""SFT configuration subset"""
|
||||
|
||||
@@ -466,6 +474,7 @@ class HyperparametersConfig(BaseModel):
|
||||
cosine_min_lr_ratio: Optional[float] = None
|
||||
cosine_constant_lr_ratio: Optional[float] = None
|
||||
lr_div_factor: Optional[float] = None
|
||||
lr_groups: Optional[List[LrGroup]] = None
|
||||
|
||||
adam_epsilon: Optional[float] = None
|
||||
adam_beta1: Optional[float] = None
|
||||
|
||||
Reference in New Issue
Block a user