From d8787a433f7a5f2c4db985f8551246a5f95a6e71 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 29 Nov 2024 20:38:20 -0500 Subject: [PATCH] support seperate lr for embeddings, similar to loraplus (#1910) [skip ci] * support seperate lr for embeddings, similar to loraplus * add test case for train w lr embedding scale * use kwarg for optimizer * make sure to handle the optimizer creation * make sure to handle for embedding_lr too * use smollm for e2e, check for embeddings lr first before wdecay --- src/axolotl/core/trainer_builder.py | 87 ++++++++++--- .../config/models/input/v0_4_1/__init__.py | 2 + tests/e2e/test_embeddings_lr.py | 121 ++++++++++++++++++ 3 files changed, 191 insertions(+), 19 deletions(-) create mode 100644 tests/e2e/test_embeddings_lr.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9e906179b..8e85ad8ba 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -220,6 +220,14 @@ class AxolotlTrainingMixins: default=1e-6, metadata={"help": "loraplus learning rate for lora embedding layers."}, ) + embedding_lr_scale: Optional[float] = field( + default=None, + metadata={"help": "Scale the learning rate for the embedding layers."}, + ) + embedding_lr: Optional[float] = field( + default=None, + metadata={"help": "absolute learning rate for the embedding layers."}, + ) qlora: bool = field( default=False, metadata={"help": "whether this is a qlora training"}, @@ -386,7 +394,7 @@ class SchedulerMixin(Trainer): min_lr_ratio=self.args.cosine_min_lr_ratio, ) else: - return super().create_scheduler(num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer=optimizer) else: if use_cosine_quadratic: LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") @@ -435,6 +443,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer): def create_optimizer(self): if ( self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None and self.args.alternate_optimizer not in [ "optimi_adamw", @@ -449,30 +459,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer): opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: # pylint: disable=access-member-before-definition decay_parameters = self.get_decay_parameter_names(opt_model) - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in opt_model.named_parameters() - if (n in decay_parameters and p.requires_grad) - ], - "weight_decay": self.args.weight_decay, - }, - { - "params": [ - p - for n, p in opt_model.named_parameters() - if (n not in decay_parameters and p.requires_grad) - ], - "weight_decay": 0.0, - }, - ] + params = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( self.args, opt_model, ) + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + 
"weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + if self.args.loraplus_lr_ratio is not None: loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_embedding = getattr( @@ -485,6 +524,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer): loraplus_lr_embedding=loraplus_lr_embedding, **optimizer_kwargs, ) + elif ( + self.args.embedding_lr_scale is not None + or self.args.embedding_lr is not None + ): + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + ) elif self.args.alternate_optimizer == "optimi_adamw": from optimi import AdamW @@ -1571,6 +1617,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "loraplus_lr_embedding" ] = self.cfg.loraplus_lr_embedding + training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr + training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale + if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs[ diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c5b319919..c7d5848f9 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -431,6 +431,8 @@ class HyperparametersConfig(BaseModel): group_by_length: Optional[bool] = None learning_rate: Union[str, float] + embedding_lr: Optional[float] = None + embedding_lr_scale: Optional[float] = None weight_decay: Optional[float] = 0.0 optimizer: Optional[ Union[ diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py new file mode 100644 index 000000000..bc406caf3 --- /dev/null +++ b/tests/e2e/test_embeddings_lr.py @@ -0,0 +1,121 @@ +""" +E2E tests for llama pretrain +""" + +import logging +import os +import unittest +from pathlib import Path + +from tbparse import SummaryReader + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import most_recent_subdir, with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestEmbeddingsLrScale(unittest.TestCase): + """ + Test case for embedding_lr* + """ + + @with_temp_dir + def test_train_w_embedding_lr_scale(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": True, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "max_steps": 5, + "num_epochs": 1, + 
"micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "embedding_lr_scale": 0.5, + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name + assert df.value.values[-1] < 2.0, "Loss is too high" + + @with_temp_dir + def test_train_w_embedding_lr(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": True, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "max_steps": 5, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "embedding_lr": 0.000005, + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name + assert df.value.values[-1] < 2.0, "Loss is too high"