From 887513285d98132142bf5db2a74eb5e0928787f1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 24 Jan 2025 12:56:28 -0500
Subject: [PATCH] support for custom lr groups for non-embedding modules
 (#2213)

* support for custom lr groups for non-embedding modules

invert name check for group modules

include lr_groups in training args

additional conditional for creating optimizer

fix regular params as w weight decay

fix lookup and add docs

* address pr feedback

---
 docs/lr_groups.qmd                          |  29 ++++
 src/axolotl/core/trainer_builder.py         | 142 ++++++++++++------
 .../config/models/input/v0_4_1/__init__.py  |   9 ++
 3 files changed, 131 insertions(+), 49 deletions(-)
 create mode 100644 docs/lr_groups.qmd

diff --git a/docs/lr_groups.qmd b/docs/lr_groups.qmd
new file mode 100644
index 000000000..52059016c
--- /dev/null
+++ b/docs/lr_groups.qmd
@@ -0,0 +1,29 @@
+---
+title: Learning Rate Groups
+description: "Setting different learning rates by module name"
+---
+
+## Background
+
+Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for individual modules or
+groups of modules in a model.
+
+## Example
+
+```yaml
+lr_groups:
+  - name: o_proj
+    modules:
+      - self_attn.o_proj.weight
+    lr: 1e-6
+  - name: q_proj
+    modules:
+      - model.layers.2.self_attn.q_proj.weight
+    lr: 1e-5
+
+learning_rate: 2e-5
+```
+
+In this example, the default learning rate of 2e-5 applies across the entire model, but a separate learning rate
+of 1e-6 is used for all the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for
+the 3rd layer's self-attention `q_proj` module.
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 62c6a9721..d63a10e74 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -243,6 +243,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups with different LRs."},
+    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -461,11 +465,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
+    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
+        decay_parameters = self.get_decay_parameter_names(opt_model)
+        params = {
+            "to_weight_decay": {},  # LayerNorm and bias
+            "embeddings": {},  # lm_head, embed_tokens,
+            "no_weight_decay": {},
+        }
+        lr_groups_lookup = {}
+        lr_groups_learning_rates = {}
+        if self.args.lr_groups:
+            for lr_group in self.args.lr_groups:
+                group_name = lr_group["name"]
+                group_modules = lr_group["modules"]
+                for module in group_modules:
+                    lr_groups_lookup[module] = group_name
+                lr_groups_learning_rates[group_name] = lr_group["lr"]
+                params[f"to_weight_decay_{group_name}"] = {}
+
+        for name, param in opt_model.named_parameters():
+            if not param.requires_grad:
+                continue
+            if name.endswith("modules_to_save.default.weight") or any(
+                embed_name in name for embed_name in ["embed_tokens", "lm_head"]
+            ):
+                params["embeddings"][name] = param
+            elif name in decay_parameters:
+                lr_group_modules = [
+                    group_modules
+                    for group_modules in lr_groups_lookup
+                    if group_modules in name
+                ]
+                if lr_groups_lookup and any(lr_group_modules):
+                    lr_group_module = lr_group_modules[0]
+                    group_name = lr_groups_lookup[lr_group_module]
+                    params[f"to_weight_decay_{group_name}"][name] = param
+                else:
+                    params["to_weight_decay"][name] = param
+            else:
+                params["no_weight_decay"][name] = param
+        optimizer_grouped_parameters = []
+        if params["to_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["to_weight_decay"].values()),
+                    "weight_decay": self.args.weight_decay,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        if params["embeddings"]:
+            lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+            if self.args.embedding_lr_scale:
+                lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
+            elif self.args.embedding_lr:
+                lr = self.args.embedding_lr  # pylint: disable=invalid-name
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["embeddings"].values()),
+                    "weight_decay": 0.0,
+                    "lr": lr,
+                }
+            )
+        if params["no_weight_decay"]:
+            optimizer_grouped_parameters.append(
+                {
+                    "params": list(params["no_weight_decay"].values()),
+                    "weight_decay": 0.0,
+                    "lr": optimizer_kwargs["lr"],
+                }
+            )
+        for group_name, group_lr in lr_groups_learning_rates.items():
+            if params[f"to_weight_decay_{group_name}"]:
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(
+                            params[f"to_weight_decay_{group_name}"].values()
+                        ),
+                        "weight_decay": self.args.weight_decay,
+                        "lr": group_lr,
+                    }
+                )
+
+        return optimizer_grouped_parameters
+
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
+            and self.args.lr_groups is None
             and self.args.alternate_optimizer
             not in [
                 "optimi_adamw",
@@ -479,59 +567,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            decay_parameters = self.get_decay_parameter_names(opt_model)
-            params = {
-                "to_weight_decay": {},  # LayerNorm and bias
-                "embeddings": {},  # lm_head, embed_tokens,
-                "no_weight_decay": {},
-            }
-
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
-
-            for name, param in opt_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-                if name.endswith("modules_to_save.default.weight") or any(
-                    embed_name in name for embed_name in ["embed_tokens", "lm_head"]
-                ):
-                    params["embeddings"][name] = param
-                elif name in decay_parameters:
-                    params["to_weight_decay"][name] = param
-                else:
-                    params["no_weight_decay"][name] = param
-            optimizer_grouped_parameters = []
-            if params["to_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["to_weight_decay"].values()),
-                        "weight_decay": self.args.weight_decay,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
-            if params["embeddings"]:
-                lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
-                if self.args.embedding_lr_scale:
-                    lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
-                elif self.args.embedding_lr:
-                    lr = self.args.embedding_lr  # pylint: disable=invalid-name
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["embeddings"].values()),
-                        "weight_decay": 0.0,
-                        "lr": lr,
-                    }
-                )
-            if params["no_weight_decay"]:
-                optimizer_grouped_parameters.append(
-                    {
-                        "params": list(params["no_weight_decay"].values()),
-                        "weight_decay": 0.0,
-                        "lr": optimizer_kwargs["lr"],
-                    }
-                )
+            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
+                opt_model, optimizer_kwargs
+            )
 
             if self.args.loraplus_lr_ratio is not None:
                 loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -548,6 +590,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         elif (
             self.args.embedding_lr_scale is not None
             or self.args.embedding_lr is not None
+            or self.args.lr_groups is not None
         ):
             self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                 optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1665,6 +1708,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
 
         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 98cdee009..44e247886 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -147,6 +147,14 @@ class UserDefinedPrompterType(BaseModel):
     field: Optional[str] = None
 
 
+class LrGroup(BaseModel):
+    """Custom learning rate group configuration"""
+
+    name: str
+    modules: List[str]
+    lr: float
+
+
 class SFTDataset(BaseModel):
     """SFT configuration subset"""
 
@@ -475,6 +483,7 @@ class HyperparametersConfig(BaseModel):
     cosine_min_lr_ratio: Optional[float] = None
     cosine_constant_lr_ratio: Optional[float] = None
     lr_div_factor: Optional[float] = None
+    lr_groups: Optional[List[LrGroup]] = None
     adam_epsilon: Optional[float] = None
     adam_beta1: Optional[float] = None