diff --git a/docs/optimizers.md b/docs/optimizers.md
new file mode 100644
index 000000000..8d45d5f2f
--- /dev/null
+++ b/docs/optimizers.md
@@ -0,0 +1,33 @@
+# Optimizers
+
+Optimizers are a core component of LLM training: they are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation.
+The goal of an optimizer is to minimize the loss function.
+
+### Adam/AdamW Optimizers
+
+The Adam/AdamW family of optimizers is configured with the following hyperparameters:
+
+```yaml
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_epsilon: 1e-8
+weight_decay: 0.0
+```
+
+### GaLore Optimizer
+
+GaLore (Gradient Low-Rank Projection) reduces optimizer memory usage by projecting gradients into a low-rank subspace. See the paper for details:
+
+https://huggingface.co/papers/2403.03507
+
+```yaml
+optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
+optim_args:
+  rank: 128
+  update_proj_gap: 200
+  scale: 0.25
+  proj_type: std
+optim_target_modules:
+  - mlp
+  - attn
+```
diff --git a/requirements.txt b/requirements.txt
index aaa27c547..6a5491d38 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -41,3 +41,5 @@
 gcsfs
 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
 fastcore>=1.5.29
+
+lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b23a8a124..d52ac7b07 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -15,8 +15,9 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
 
+import lpmm
 import torch
 import transformers
 from accelerate import FullyShardedDataParallelPlugin
@@ -27,6 +28,7 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     EarlyStoppingCallback,
+    PreTrainedModel,
     Trainer,
     TrainerCallback,
     TrainingArguments,
@@ -36,6 +38,7 @@ from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 
 from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
+from axolotl.core.trainers import OptimizerNames
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -62,6 +65,9 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_warmup_decay_constant,
 )
 
+# monkeypatch transformers so TrainingArguments accepts our custom optimizer names
+transformers.training_args.OptimizerNames = OptimizerNames
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
@@ -231,6 +237,43 @@ class AxolotlTrainer(Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
+    @staticmethod
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
+        """
+        Return the optimizer class and kwargs for ``args.optim``, handling the
+        lpmm 4-bit AdamW variants before falling back to the transformers default.
+        """
+        optim_args = {}
+        if args.optim_args:
+            for mapping in args.optim_args.replace(" ", "").split(","):
+                key, value = mapping.split("=")
+                optim_args[key] = value
+
+        optimizer_kwargs = {"lr": args.learning_rate}
+
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+
+        if args.optim in [
+            OptimizerNames.LPMM_ADAMW_4BIT,
+            OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
+        ]:
+            optimizer_cls = lpmm.optim.AdamW
+            optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
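+                # assumption: lpmm's AdamW accepts a "fused" kwarg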
+                optimizer_kwargs.update({"fused": True})
+            return optimizer_cls, optimizer_kwargs
+
+        return Trainer.get_optimizer_cls_and_kwargs(
+            args,
+            model=model,
+        )
+
     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:
             return super().create_optimizer()
diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py
index e69de29bb..504990fec 100644
--- a/src/axolotl/core/trainers/__init__.py
+++ b/src/axolotl/core/trainers/__init__.py
@@ -0,0 +1,40 @@
+"""Module for trainer helpers like OptimizerNames"""
+
+from transformers.utils import ExplicitEnum
+
+
+class OptimizerNames(ExplicitEnum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_HF = "adamw_hf"
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
+    ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+    ADAMW_ANYPRECISION = "adamw_anyprecision"
+    SGD = "sgd"
+    ADAGRAD = "adagrad"
+    ADAMW_BNB = "adamw_bnb_8bit"
+    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    LION_8BIT = "lion_8bit"
+    LION = "lion_32bit"
+    PAGED_ADAMW = "paged_adamw_32bit"
+    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_LION = "paged_lion_32bit"
+    PAGED_LION_8BIT = "paged_lion_8bit"
+    RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"
+    GALORE_ADAMW = "galore_adamw"
+    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+    GALORE_ADAFACTOR = "galore_adafactor"
+    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    LPMM_ADAMW_4BIT = "lpmm_adamw_4bit"
+    LPMM_ADAMW_4BIT_FUSED = "lpmm_adamw_4bit_fused"
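
For context (not part of the patch above), here is a minimal config sketch showing how the new optimizer could be selected from a YAML config once this change lands. It assumes the `lpmm_adamw_4bit` / `lpmm_adamw_4bit_fused` identifiers registered in `OptimizerNames` and follows the same `optimizer:` and Adam hyperparameter conventions already used in `docs/optimizers.md`:

```yaml
# select lpmm's 4-bit AdamW via the identifier registered in OptimizerNames
optimizer: lpmm_adamw_4bit
# or, for the fused variant handled in get_optimizer_cls_and_kwargs:
# optimizer: lpmm_adamw_4bit_fused

# standard Adam hyperparameters; these map to the lr/betas/eps kwargs
# that get_optimizer_cls_and_kwargs forwards to lpmm.optim.AdamW
learning_rate: 0.0002
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
```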