add support for 4bit optimizers
docs/optimizers.md (new file, 29 lines added)
@@ -0,0 +1,29 @@
# Optimizers

Optimizers are a key component when training LLMs: they update the model's weights (parameters) based on the gradients computed during backpropagation. The goal of an optimizer is to minimize the loss function.
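The optimizer is selected with the `optimizer` key of the axolotl config. As a minimal sketch (assuming the standard config keys; `adamw_torch` is one of the identifiers accepted by the `OptimizerNames` enum below):

```yaml
optimizer: adamw_torch
learning_rate: 0.0002
```
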
### Adam/AdamW Optimizers

```yaml
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
weight_decay: 0.0
```
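The beta and epsilon values above are forwarded to the underlying optimizer as its `betas` and `eps` arguments (see the `adam_kwargs` dict in the trainer changes below).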

### GaLore Optimizer

https://huggingface.co/papers/2403.03507

```yaml
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn
```
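
### 4-bit Optimizers (lpmm)

https://github.com/thu-ml/low-bit-optimizers

A minimal sketch of enabling the lpmm 4-bit AdamW variants introduced by this commit, using the string identifiers from the `OptimizerNames` enum below (note the `lmpp_` spelling):

```yaml
optimizer: lmpp_adamw_4bit  # or lmpp_adamw_4bit_fused
```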
@@ -41,3 +41,5 @@ gcsfs
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29

lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
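The `lpmm` package (thu-ml's low-bit-optimizers) provides the 4-bit AdamW implementation (`lpmm.optim.AdamW`) that the trainer changes below instantiate.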
@@ -15,8 +15,9 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import lpmm
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
@@ -27,6 +28,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
    EarlyStoppingCallback,
    PreTrainedModel,
    Trainer,
    TrainerCallback,
    TrainingArguments,
@@ -36,6 +38,7 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.core.trainers import OptimizerNames
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -62,6 +65,9 @@ from axolotl.utils.schedulers import (
    get_cosine_schedule_with_warmup_decay_constant,
)

# monkeypatch so it accepts our custom optimizers
transformers.training_args.OptimizerNames = OptimizerNames

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
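The assignment above swaps the enum that `transformers.training_args` uses to validate the `optim` value, so identifiers such as `lmpp_adamw_4bit` become acceptable. A rough, illustrative sketch of the effect (assuming axolotl and transformers are importable):

```python
import transformers

from axolotl.core.trainers import OptimizerNames

# Mirror the monkeypatch from this hunk: replace the enum that
# transformers' training args use to validate the `optim` string.
transformers.training_args.OptimizerNames = OptimizerNames

# The custom identifier now resolves through the patched enum.
assert (
    transformers.training_args.OptimizerNames("lmpp_adamw_4bit")
    is OptimizerNames.LPMM_ADAMW_4BIT
)
```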
@@ -231,6 +237,38 @@ class AxolotlTrainer(Trainer):
        if self.args.orpo_alpha:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    @staticmethod
    def get_optimizer_cls_and_kwargs(
        args: TrainingArguments, model: Optional[PreTrainedModel] = None
    ) -> Tuple[Any, Any]:
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }

        if args.optim in [
            OptimizerNames.LPMM_ADAMW_4BIT,
            OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
        ]:
            optimizer_cls = lpmm.optim.AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
                optimizer_kwargs.update({"fused": True})
            return optimizer_cls, optimizer_kwargs

        return Trainer.get_optimizer_cls_and_kwargs(
            args,
            model=model,
        )

    def create_optimizer(self):
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()
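For context, `Trainer.create_optimizer` consumes the `(class, kwargs)` pair returned by `get_optimizer_cls_and_kwargs`, so for the 4-bit path the construction reduces to roughly the following. This is an illustrative sketch only: the tiny model is hypothetical, the hyperparameter values are examples, and lpmm's 4-bit kernels generally expect CUDA:

```python
import lpmm
import torch

# Hypothetical stand-in for the real model.
model = torch.nn.Linear(16, 16).cuda()

# Mirrors the optimizer_kwargs/adam_kwargs assembled in
# get_optimizer_cls_and_kwargs above for the lpmm 4-bit path.
optimizer = lpmm.optim.AdamW(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    fused=True,  # only set for the lmpp_adamw_4bit_fused variant
)

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optimizer.step()
```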
@@ -0,0 +1,40 @@
"""module for trainer helpers like OptimizerNames"""

from transformers.utils import ExplicitEnum


class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    RMSPROP = "rmsprop"
    RMSPROP_BNB = "rmsprop_bnb"
    RMSPROP_8BIT = "rmsprop_bnb_8bit"
    RMSPROP_32BIT = "rmsprop_bnb_32bit"
    GALORE_ADAMW = "galore_adamw"
    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
    GALORE_ADAFACTOR = "galore_adafactor"
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LPMM_ADAMW_4BIT = "lmpp_adamw_4bit"
    LPMM_ADAMW_4BIT_FUSED = "lmpp_adamw_4bit_fused"
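Note that the new 4-bit identifiers are spelled `lmpp_adamw_4bit` and `lmpp_adamw_4bit_fused` (letters transposed relative to the `lpmm` package name), so config values must match these exact strings.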